一、安装

git clone https://github.com/westny/dronalize.git

cd dronalize

conda env create -f build/environment.yml

二、数据准备inD

把inD数据集合放到../datasets里, 并在dronalize项目的configs增加inD配置

inD配置可以直接复制rounD的配置, 只需把example.yml文件里的dataset改为inD

三、数据预处理

python -m preprocessing.preprocess_urban --config 'inD' --use-threads $2 --add-supp 0

预处理完后,会在data目录生成数据集

四、训练

python train.py --add-name Test --dry-run 0 --use-cuda 1 --num-workers 4 --use-logger

这里需要注意, 程序会遍历configs里的配置, 要确保每个配置里的dataset已经预处理好

--pre-train:

python train.py --add-name Test --dry-run 0 --use-cuda 1 --num-workers 4 --use-logger 1 --pre-train inD/Net-Test-inD

这个是在训练好的模型基础上继续训练

五、测试

修改:models/prototype/litmodule.py

增加:

    def test_step(self, data: HeteroData, *args) -> None:
        ma_mask = data['agent']['ma_mask']
        ptr = data['agent']['ptr']

        loss, pred, trg = self(data)

        self.min_ade.update(pred, trg, mask=ma_mask)
        self.min_fde.update(pred, trg, mask=ma_mask)
        self.min_apde.update(pred, trg, mask=ma_mask)
        self.mr.update(pred, trg, mask=ma_mask)

        metric_dict = {"val_loss": loss,
                       "val_min_ade": self.min_ade,
                       "val_min_fde": self.min_fde,
                       "val_min_apde": self.min_apde,
                       "val_mr": self.mr}

        self.log_dict(metric_dict, on_step=False, on_epoch=True,
                      batch_size=trg.size(0), prog_bar=True) 

python test.py --add-name Test --use-cuda 1 --num-workers 4 --use-logger 1

这里需要注意, 程序会遍历configs里的配置, 要确保每个配置里的dataset已经预处理好