Training dynamic Conformer models for Automatic Speech Recognition (ASR) using early-exit training techniques.
## Paper
Our original paper on early-exit training, *Training dynamic models using early exits for automatic speech recognition on resource-constrained devices*, is available on arXiv.
Our more recent work, *Splitformer: An improved early-exit architecture for automatic speech recognition on edge devices*, is also on arXiv, with pretrained model weights available on Hugging Face.
## Acknowledgements
This repository incorporates code from the Transformer PyTorch implementation by Hyunwoong Ko and the SentencePiece unsupervised tokenizer.
## Basic usage
| Description | Command |
|---|---|
| Training an Attention Encoder-Decoder (AED)-based model | `train.py --decoder_mode aed` |
| Training a CTC-based model | `train.py --decoder_mode ctc --model_type model_name` |
| Inference with an Attention Encoder-Decoder-based model | `inference.py --decoder_mode aed --load_model_path /path/to/model` |
| Inference with a CTC-based model | `inference.py --decoder_mode ctc --model_type model_name --load_model_path /path/to/model` |
## Advanced usage examples
| Description | Command |
|---|---|
| Training an AED-based model with 6 exits, one placed every 3 layers, for a total of 18 layers | `train.py --decoder_mode aed --n_enc_exits 6 --n_enc_layers_per_exit 3` |
| Training a CTC-based model for 75 epochs with an initial learning rate of 1e-6, initialized from a pre-trained checkpoint at the given path | `train.py --decoder_mode ctc --model_type model_name --n_epochs 75 --init_lr 1e-6 --load_model_path /path/to/model` |
| Inference with an AED-based architecture, averaging model checkpoints from epochs 95 through 100 found in the directory at the given path | `inference.py --decoder_mode aed --load_model_dir /path/to/dir --avg_model_start 95 --avg_model_end 100` |
See below for additional configuration options.
Note: SentencePiece is used to tokenize target labels.
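For context, a minimal SentencePiece round trip might look as follows; the file names and vocabulary size here are illustrative assumptions, not the repository's actual settings.

```python
import sentencepiece as spm

# Train a tokenizer on the target transcripts. 'transcripts.txt',
# 'sp_labels', and vocab_size=256 are hypothetical placeholders.
spm.SentencePieceTrainer.train(
    input="transcripts.txt", model_prefix="sp_labels", vocab_size=256
)

sp = spm.SentencePieceProcessor(model_file="sp_labels.model")
ids = sp.encode("hello world", out_type=int)  # label IDs used as training targets
text = sp.decode(ids)                         # decode back to text
```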
## Training setup and options
| Variable | Default value | Description |
|---|---|---|
| `--decoder_mode` | -- | Required: whether to use a connectionist temporal classification-based (`ctc`) or attention encoder-decoder-based (`aed`) decoder |
| `--model_type` | `early_conformer` | Model to use: `early_conformer`, `splitformer`, or `early_zipformer` (CTC decoder only) |
| `--n_epochs` | `10000` | Number of training epochs |
| `--n_threads` | `10` | Number of threads for intra-op parallelism on CPU; see the PyTorch `torch.set_num_threads` method |
| `--n_workers` | `10` | Number of worker processes for loading data |
| `--shuffle` | `True` | Shuffle training data upon loading |
| `--save_model_dir` | `/trained_model` | Directory in which to save model checkpoints |
| `--load_model_path` | `None` | Path to a model checkpoint to load for training/inference |
| `--load_model_dir` | `None` | Directory containing model checkpoints for model averaging |
| `--avg_model_start` | `None` | Starting epoch for model averaging |
| `--avg_model_end` | `None` | End epoch for model averaging (see the averaging sketch after the notes below) |
- Note 1: In addition to the specified number of conformers and layers per conformer, the `splitformer` model automatically includes one extra parallel downsampled layer (a conformer with a single layer) before the first and last exits, which adds a total of 2 extra layers compared to the `early_conformer` model with the same parameters.
- Note 2: When using `early_zipformer` (the U-Net-modified model), its architecture requires `--n_enc_exits 19` and `--n_enc_layers_per_exit 1`; these values can only be changed by adjusting the model's structure.
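For reference, the kind of checkpoint averaging driven by `--load_model_dir`, `--avg_model_start`, and `--avg_model_end` can be sketched as below; the checkpoint filename pattern is a hypothetical placeholder, not necessarily the repository's naming scheme.

```python
import torch

def average_checkpoints(model_dir: str, start: int, end: int) -> dict:
    """Average the parameter tensors of checkpoints from epochs start..end.

    NOTE: the 'model_{epoch}.pt' pattern is a hypothetical placeholder;
    adapt it to the files actually present in --load_model_dir.
    """
    paths = [f"{model_dir}/model_{epoch}.pt" for epoch in range(start, end + 1)]
    avg_state = None
    for path in paths:
        state = torch.load(path, map_location="cpu")
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg_state[k] += v.float()
    return {k: v / len(paths) for k, v in avg_state.items()}

# Roughly what: inference.py --load_model_dir /path/to/dir \
#   --avg_model_start 95 --avg_model_end 100  would compute:
# model.load_state_dict(average_checkpoints("/path/to/dir", 95, 100))
```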
## Model parameters
| Variable | Default value | Description |
|---|---|---|
| `--batch_size` | `64` | Batch size during training and inference |
| `--n_batch_split` | `4` | Items in each batch are ordered by length and split into this number of sub-batches, to minimize padding and maximize GPU performance |
| `--max_len` | `2000` | Maximum length of model inputs, in number of characters |
| `--d_model` | `256` | Dimensionality of the model |
| `--n_enc_layers_per_exit` | `2` | Number of encoder layers per exit (the number of exits is set by `--n_enc_exits`). For example, `--n_enc_layers_per_exit=2` and `--n_enc_exits=6` yields an encoder with 6 exits and 12 total layers, with an exit every 2 layers |
| `--n_enc_exits` | `6` | Number of exits in the model (the number of layers per exit is set by `--n_enc_layers_per_exit`). For example, `--n_enc_layers_per_exit=2` and `--n_enc_exits=6` yields an encoder with 6 exits and 12 total layers, with an exit every 2 layers |
| `--n_dec_layers` | `6` | Number of decoder layers for each exit in the encoder |
| `--n_heads` | `6` | Number of attention heads in each encoder layer |
| `--d_feed_forward` | `2048` | Dimensionality of the feed-forward network |
| `--drop_prob` | `0.1` | Probability that a given element of the input is randomly dropped during training |
| `--depthwise_kernel_size` | `31` | Kernel size of the depthwise convolutions in each Conformer block |
| `--max_utterance_length` | `360` | Input items longer than this value, in number of labels, are dropped during training |
| `--aed_ce_weight` | `0.7` | For AED models: weight coefficient for the cross-entropy loss (see the loss sketch below) |
| `--aed_ctc_weight` | `0.3` | For AED models: weight coefficient for the CTC loss |
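The two weights above suggest the standard joint CTC/attention objective, `loss = aed_ce_weight * CE + aed_ctc_weight * CTC`. The sketch below illustrates that weighted sum under assumed tensor shapes; it is not a copy of the repository's training code.

```python
import torch
import torch.nn as nn

# Dummy shapes for illustration (these values are assumptions, not repo
# constants): B=2 utterances, T=50 encoder frames, U=10 labels, V=32 vocab.
B, T, U, V, pad_id, blank_id = 2, 50, 10, 32, 0, 0
dec_logits = torch.randn(B, U, V)                     # decoder outputs
enc_log_probs = torch.randn(T, B, V).log_softmax(-1)  # encoder outputs for CTC
targets = torch.randint(1, V, (B, U))
input_lengths = torch.full((B,), T)
target_lengths = torch.full((B,), U)

ce = nn.CrossEntropyLoss(ignore_index=pad_id)(dec_logits.transpose(1, 2), targets)
ctc = nn.CTCLoss(blank=blank_id, zero_infinity=True)(
    enc_log_probs, targets, input_lengths, target_lengths
)
loss = 0.7 * ce + 0.3 * ctc  # --aed_ce_weight and --aed_ctc_weight defaults
```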
## Audio preprocessing
| Variable | Default value | Description |
|---|---|---|
| `--sample_rate` | `16000` | Sample rate used when preprocessing raw audio inputs to the model |
| `--n_fft` | `512` | Size of the Fast Fourier Transform used to generate the spectrogram of the raw audio input during preprocessing |
| `--win_length` | `320` | Window length used to generate the spectrogram of the raw audio input during preprocessing |
| `--hop_length` | `160` | Length of the hop between STFT windows used to generate the spectrogram of the raw audio input during preprocessing |
| `--n_mels` | `80` | Number of mel filterbanks applied to the STFT of the raw audio input during preprocessing |
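Together these defaults describe a standard 80-dimensional mel-spectrogram front end (20 ms windows with a 10 ms hop at 16 kHz). A minimal sketch using `torchaudio`, where the choice of library and the input file name are assumptions:

```python
import torchaudio

# Mel-spectrogram front end matching the defaults above; using torchaudio
# here is an illustrative assumption about the preprocessing stack.
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,  # --sample_rate
    n_fft=512,          # --n_fft
    win_length=320,     # --win_length (20 ms at 16 kHz)
    hop_length=160,     # --hop_length (10 ms at 16 kHz)
    n_mels=80,          # --n_mels
)

waveform, sr = torchaudio.load("utterance.wav")  # hypothetical input file
features = mel_transform(waveform)  # shape: (channels, n_mels, frames)
```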
## Optimization
| Variable | Default value | Description |
|---|---|---|
| `--init_lr` | `1e-5` | Initial learning rate during training |
| `--adam_eps` | `1e-9` | Epsilon parameter used in the AdamW optimization algorithm |
| `--weight_decay` | `5e-4` | Weight decay coefficient used in the AdamW optimization algorithm |
| `--warmup` | `-1` | Number of learning rate warmup steps; the default (`-1`) warms up for the length of the dataloader |
| `--clip` | `1.0` | Gradient norms higher than this value are clipped during training; see the PyTorch `torch.nn.utils.clip_grad_norm_` function |
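Wired together, these options correspond to an AdamW optimizer plus gradient-norm clipping. The sketch below shows that wiring with the default values; the stand-in model and loss are placeholders, not the repository's training loop.

```python
import torch

# Optimizer and clipping set up with the defaults above. The tiny model
# and loss are placeholders for the actual Conformer and its loss.
model = torch.nn.Linear(4, 4)  # placeholder for the Conformer model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-5,            # --init_lr
    eps=1e-9,           # --adam_eps
    weight_decay=5e-4,  # --weight_decay
)

loss = model(torch.randn(2, 4)).sum()  # placeholder loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # --clip
optimizer.step()
optimizer.zero_grad()
```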
## Inference parameters
| Variable | Default value | Description |
|---|---|---|
| `--beam_size` | `10` | Beam size for AED beam search inference |
| `--pen_alpha` | `1.0` | Sentence length penalty for AED beam search inference |
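The exact form of the length penalty is not documented here; a common choice, shown purely as an assumption, is GNMT-style length normalization, where each hypothesis's summed log-probability is divided by a length term raised to `pen_alpha`.

```python
# Hypothetical length-penalized beam score; the formula below is the
# GNMT-style penalty, an assumption rather than inference.py's exact code.
def beam_score(sum_log_prob: float, length: int, pen_alpha: float = 1.0) -> float:
    length_penalty = ((5.0 + length) / 6.0) ** pen_alpha
    return sum_log_prob / length_penalty

# At each decoding step, the --beam_size hypotheses with the highest
# normalized scores are kept.
```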