一、安装
git clone https://github.com/westny/dronalize.git
cd dronalize
conda env create -f build/environment.yml
二、数据准备inD
把inD数据集合放到../datasets里, 并在dronalize项目的configs增加inD配置
inD配置可以直接复制rounD的配置, 只需把example.yml文件里的dataset改为inD
三、数据预处理
python -m preprocessing.preprocess_urban --config 'inD' --use-threads $2 --add-supp 0
预处理完后,会在data目录生成数据集
四、训练
python train.py --add-name Test --dry-run 0 --use-cuda 1 --num-workers 4 --use-logger
这里需要注意, 程序会遍历configs里的配置, 要确保每个配置里的dataset已经预处理好
--pre-train:
python train.py --add-name Test --dry-run 0 --use-cuda 1 --num-workers 4 --use-logger 1 --pre-train inD/Net-Test-inD
这个是在训练好的模型基础上继续训练
五、测试
修改:models/prototype/litmodule.py
增加:
def test_step(self, data: HeteroData, *args) -> None:
ma_mask = data['agent']['ma_mask']
ptr = data['agent']['ptr']
loss, pred, trg = self(data)
self.min_ade.update(pred, trg, mask=ma_mask)
self.min_fde.update(pred, trg, mask=ma_mask)
self.min_apde.update(pred, trg, mask=ma_mask)
self.mr.update(pred, trg, mask=ma_mask)
metric_dict = {"val_loss": loss,
"val_min_ade": self.min_ade,
"val_min_fde": self.min_fde,
"val_min_apde": self.min_apde,
"val_mr": self.mr}
self.log_dict(metric_dict, on_step=False, on_epoch=True,
batch_size=trg.size(0), prog_bar=True)
python test.py --add-name Test --use-cuda 1 --num-workers 4 --use-logger 1
这里需要注意, 程序会遍历configs里的配置, 要确保每个配置里的dataset已经预处理好