
This repository contains MINIMAL implementations of various positional encoding methods for Transformers. It is well suited for learning purposes.


SimplePositionalEncoding

How to play with the code?

Environment and dataset requirements

Since the environment is simple, there is no requirements.txt. All you need is a basic torch setup; the following command is provided for reference:

pip3 install torch torchvision tensorboard

For the dataset, this project uses CIFAR-10, which is already implemented in torch; it is downloaded automatically when you run the training code. At only a bit over 100 MB, it needs no extra installation steps.
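
Under the hood this is just the standard torchvision dataset API; a minimal sketch (the root path here is an assumption, the training script chooses its own):

import torchvision
import torchvision.transforms as transforms

train_set = torchvision.datasets.CIFAR10(
    root="./data",                       # assumed download location
    train=True,
    download=True,                       # fetched automatically on first run
    transform=transforms.ToTensor(),
)
print(len(train_set), train_set.classes)  # 50000 images, 10 class names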

Overview

This project provides three ways to add positional encoding to a Transformer:

  • Sinusoidal absolute positional encoding
  • RoPE relative positional encoding
  • Learnable-parameter positional encoding (absolute)

For these three positional encodings, the project provides an experiment that trains a vanilla ViT classifier on CIFAR-10. The experiment implements all three encoding modes, so you can learn how each one is used in practice and compare their effects.
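
To make the three options concrete, here is a minimal sketch of the sinusoidal ("abs") variant. It follows the standard Transformer formulation and is not necessarily this repo's exact code:

import torch

def sinusoidal_pe(seq_len: int, dim: int) -> torch.Tensor:
    # Angle frequencies 10000^(-2k/dim), one per (sin, cos) pair.
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    i = torch.arange(0, dim, 2, dtype=torch.float32)               # (dim/2,)
    freq = torch.exp(-torch.log(torch.tensor(10000.0)) * i / dim)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(pos * freq)   # even dims get sin
    pe[:, 1::2] = torch.cos(pos * freq)   # odd dims get cos
    return pe  # added to the token embeddings before the first block

pe = sinusoidal_pe(seq_len=65, dim=128)   # e.g. 64 patches + 1 CLS token
print(pe.shape)                           # torch.Size([65, 128])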

Getting started

RoPE is the most involved of the three, so we provide a minimal standalone example. You can run the corresponding script directly to observe the "relativity" of the encoding:

python rope.py
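
What that script demonstrates can be sketched in a few lines: rotate q and k by angles proportional to their positions, and the attention score depends only on the positional offset. The sketch below is a generic illustration of standard RoPE, not rope.py's actual API:

import torch

def rope(x: torch.Tensor, pos: int) -> torch.Tensor:
    # Rotate each (even, odd) coordinate pair of x by pos * theta_k.
    dim = x.shape[-1]
    i = torch.arange(0, dim, 2, dtype=torch.float32)
    theta = pos * 10000.0 ** (-i / dim)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat([x1 * torch.cos(theta) - x2 * torch.sin(theta),
                      x1 * torch.sin(theta) + x2 * torch.cos(theta)], dim=-1)

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)
a = rope(q, 3) @ rope(k, 7)     # positions (3, 7): offset 4
b = rope(q, 10) @ rope(k, 14)   # positions (10, 14): same offset 4
print(torch.allclose(a, b, atol=1e-5))  # True: the score only sees the offset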

Then comes the main training. You can start a ViT training run with any of the three positional encoding methods:

# learnable parameter positional encoding
python train.py --pe_method param

# RoPE positional encoding
python train.py --pe_method rope

# sinusoidal positional encoding
python train.py --pe_method abs

You can then use TensorBoard to inspect the training logs saved in the runs directory.
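
For example, with the standard TensorBoard CLI:

tensorboard --logdir runs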

Here is the full argument list:

usage: train_vit.py [-h] [--pe_method {param,abs,rope}] [--batch_size BATCH_SIZE] [--num_workers NUM_WORKERS] [--image_size IMAGE_SIZE]
                    [--patch_size PATCH_SIZE] [--num_classes NUM_CLASSES] [--dim DIM] [--depth DEPTH] [--heads HEADS] [--mlp_dim MLP_DIM]
                    [--dropout DROPOUT] [--emb_dropout EMB_DROPOUT] [--num_epochs NUM_EPOCHS] [--lr LR] [--T_max T_MAX] [--log_freq LOG_FREQ]

Train ViT on CIFAR-10 with configurable PE method.

options:
  -h, --help            show this help message and exit
  --pe_method {param,abs,rope}
                        Positional Encoding method: 'param', 'abs', or 'rope'
  --batch_size BATCH_SIZE
                        Batch size for training and testing.
  --num_workers NUM_WORKERS
                        Number of workers for data loading.
  --image_size IMAGE_SIZE
                        Input image size.
  --patch_size PATCH_SIZE
                        Patch size for ViT.
  --num_classes NUM_CLASSES
                        Number of output classes.
  --dim DIM             Embedding dimension.
  --depth DEPTH         Number of transformer blocks.
  --heads HEADS         Number of attention heads.
  --mlp_dim MLP_DIM     MLP hidden dimension.
  --dropout DROPOUT     Dropout rate.
  --emb_dropout EMB_DROPOUT
                        Embedding dropout rate.
  --num_epochs NUM_EPOCHS
                        Total number of training epochs.
  --lr LR               Learning rate.
  --T_max T_MAX         Cosine annealing T_max.
  --log_freq LOG_FREQ   Log frequency (steps) for TensorBoard
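
For example, a run that overrides a few of these defaults might look like this (the values are illustrative, not the repo's defaults):

python train.py --pe_method rope --batch_size 128 --num_epochs 50 --lr 3e-4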

Extras

The script tiny_cases/how_complex_rotate.py explores the relationship between complex multiplication and 2D rotation, a concept that is essential for reading the RoPE code. Studying it is strongly recommended.
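
For reference, here is a minimal sketch of that idea, independent of the script's actual code: multiplying a complex number by e^(i*theta) rotates the corresponding 2D point by theta, which is exactly how RoPE rotates each (even, odd) coordinate pair:

import cmath, math

theta = 0.5
z = complex(3.0, 4.0)                  # the point (3, 4) as a complex number
rotated = z * cmath.exp(1j * theta)    # rotation via complex multiplication

# The same rotation with the 2x2 matrix [[cos, -sin], [sin, cos]].
x = 3.0 * math.cos(theta) - 4.0 * math.sin(theta)
y = 3.0 * math.sin(theta) + 4.0 * math.cos(theta)

print(rotated.real, rotated.imag)      # matches (x, y) below
print(x, y)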

Reference

The ViT implementation is copied from this repo. Thank you!
