由于环境简单,因此没有写requirement.txt,需要的只是简单的torch环境,下面的命令供参考
pip3 install torch torchvision tensorboard
对于数据集,这次采用的是已经在torch里面implement过的cifar-10,运行训练代码的时候会自动下载,其大小仅为100Mb+ 因此不作额外安装介绍
本文提供三种给Transformer进行位置编码的方式,它们是:
- Sin 绝对位置编码
- RoPE 相对位置编码
- 可学习的参数位置编码(绝对位置编码)
对于这三种位置编码,本项目提供了一个使用Vanilla ViT训练一个对于Cifar-10的分类模型的实验,这个实验implement了三种不同的位置编码模式,用户可以通过这份代码学习不同的位置编码的实际使用 方式以及可以对不同位置编码的效果进行对比
对于RoPE,其复杂度较高,我们提供了单独测试的最小样例代码,可以直接运行对应脚本观察其对位置编码的“相对性”,运行:
python rope.py
然后便是主要的训练了,可以用下面三种不同的位置编码方式开启一个ViT的训练
# parameter position encoding
python train.py --pe_method param
# rope position encoding
python train.py --pe_method rope
# sin position encoding
python train.py --pe_method abs
然后可以使用tensorboard 查看保存在 runs 目录的训练信息
这是完整的参数列表
usage: train_vit.py [-h] [--pe_method {param,abs,rope}] [--batch_size BATCH_SIZE] [--num_workers NUM_WORKERS] [--image_size IMAGE_SIZE]
[--patch_size PATCH_SIZE] [--num_classes NUM_CLASSES] [--dim DIM] [--depth DEPTH] [--heads HEADS] [--mlp_dim MLP_DIM]
[--dropout DROPOUT] [--emb_dropout EMB_DROPOUT] [--num_epochs NUM_EPOCHS] [--lr LR] [--T_max T_MAX] [--log_freq LOG_FREQ]
Train ViT on CIFAR-10 with configurable PE method.
options:
-h, --help show this help message and exit
--pe_method {param,abs,rope}
Positional Encoding method: 'param', 'abs', or 'rope'
--batch_size BATCH_SIZE
Batch size for training and testing.
--num_workers NUM_WORKERS
Number of workers for data loading.
--image_size IMAGE_SIZE
Input image size.
--patch_size PATCH_SIZE
Patch size for ViT.
--num_classes NUM_CLASSES
Number of output classes.
--dim DIM Embedding dimension.
--depth DEPTH Number of transformer blocks.
--heads HEADS Number of attention heads.
--mlp_dim MLP_DIM MLP hidden dimension.
--dropout DROPOUT Dropout rate.
--emb_dropout EMB_DROPOUT
Embedding dropout rate.
--num_epochs NUM_EPOCHS
Total number of training epochs.
--lr LR Learning rate.
--T_max T_MAX Cosine annealing T_max.
--log_freq LOG_FREQ Log frequency (steps) for TensorBoard
脚本 tiny_cases/how_complex_rotate.py
探讨了复数乘法和二维旋转的关系,这是读懂RoPE代码非常重要的一个概念,强烈建议学习。
ViT 实现 Copied from this Repo, Thank you!