diff --git a/examples/asr/asr_eou/speech_to_text_eou_eval.py b/examples/asr/asr_eou/speech_to_text_eou_eval.py new file mode 100644 index 000000000000..17d00a385be4 --- /dev/null +++ b/examples/asr/asr_eou/speech_to_text_eou_eval.py @@ -0,0 +1,112 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage: + +```bash +TEST_MANIFEST="[/path/to/your/test_manifest.json,/path/to/your/test_manifest2.json,...]" +TEST_NAME="[test_name1,test_name2,...]" +TEST_BATCH=32 +NUM_WORKERS=8 + +PRETRAINED_NEMO=/path/to/EOU/model.nemo +CONFIG_NAME=fastconformer_transducer_bpe_streaming + +python speech_to_text_eou_eval.py \ + --config-name $CONFIG_NAME \ + ++init_from_nemo_model=$PRETRAINED_NEMO \ + ~model.train_ds \ + ~model.validation_ds \ + ++model.test_ds.defer_setup=true \ + ++model.test_ds.sample_rate=16000 \ + ++model.test_ds.manifest_filepath=$TEST_MANIFEST \ + ++model.test_ds.name=$TEST_NAME \ + ++model.test_ds.batch_size=$TEST_BATCH \ + ++model.test_ds.num_workers=$NUM_WORKERS \ + ++model.test_ds.drop_last=false \ + ++model.test_ds.force_finite=true \ + ++model.test_ds.shuffle=false \ + ++model.test_ds.pin_memory=true \ + exp_manager.create_wandb_logger=false +``` + +""" + + +import lightning.pytorch as pl +import torch + +torch.set_float32_matmul_precision("highest") +from omegaconf import DictConfig, OmegaConf, open_dict + +from nemo.collections.asr.models import ASRModel +from nemo.core.classes import typecheck +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo.utils.trainer_utils import resolve_trainer_cfg + +typecheck.set_typecheck_enabled(False) + + +def load_model(cfg: DictConfig, trainer: pl.Trainer) -> ASRModel: + if "init_from_nemo_model" in cfg: + logging.info(f"Loading model from local file: {cfg.init_from_nemo_model}") + model = ASRModel.restore_from(cfg.init_from_nemo_model, trainer=trainer) + elif "init_from_pretrained_model" in cfg: + logging.info(f"Loading model from remote: {cfg.init_from_pretrained_model}") + model = ASRModel.from_pretrained(cfg.init_from_pretrained_model, trainer=trainer) + else: + raise ValueError( + "Please provide either 'init_from_nemo_model' or 'init_from_pretrained_model' in the config file." 
+ ) + if cfg.get("init_from_ptl_ckpt", None): + logging.info(f"Loading weights from checkpoint: {cfg.init_from_ptl_ckpt}") + state_dict = torch.load(cfg.init_from_ptl_ckpt, map_location='cpu', weights_only=False)['state_dict'] + model.load_state_dict(state_dict, strict=True) + return model + + +@hydra_runner(config_path="../conf/asr_eou", config_name="fastconformer_transducer_bpe_streaming") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer)) + exp_manager(trainer, cfg.get("exp_manager", None)) + + asr_model = load_model(cfg, trainer) + asr_model = asr_model.eval() # Set the model to evaluation mode + if hasattr(asr_model, 'wer'): + asr_model.wer.log_prediction = False + + with open_dict(asr_model.cfg): + if "save_pred_to_file" in cfg: + asr_model.cfg.save_pred_to_file = cfg.save_pred_to_file + if "calclate_eou_metrics" in cfg: + asr_model.cfg.calclate_eou_metrics = cfg.calclate_eou_metrics + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + with open_dict(cfg.model.test_ds): + cfg.model.test_ds.pad_eou_label_secs = asr_model.cfg.get('pad_eou_label_secs', 0.0) + asr_model.setup_multiple_test_data(test_data_config=cfg.model.test_ds) + trainer.test(asr_model) + else: + raise ValueError( + "No test dataset provided. Please provide a test dataset in the config file under model.test_ds." + ) + logging.info("Test completed.") + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py b/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py new file mode 100644 index 000000000000..81249f703412 --- /dev/null +++ b/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py @@ -0,0 +1,348 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage: + +1. Prepare dataset based on /nemo/collections/asr/data/audio_to_eou_label_lhotse.py + Specifically, each sample in the jsonl manifest should have the following fields: + { + "audio_filepath": "/path/to/audio.wav", + "text": "The text of the audio." + "offset": 0.0, # offset of the audio, in seconds + "duration": 3.0, # duration of the audio, in seconds + "sou_time": 0.2, # start of utterance time, in seconds + "eou_time": 1.5, # end of utterance time, in seconds + } + +2. If using a normal ASR model as initialization: + - Add special tokens and to the tokenizer of pretrained model, by refering to the script + /scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py + - If pretrained model is HybridRNNTCTCBPEModel, convert it to RNNT using the script + /examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py + +3. 
Run the following command to train the ASR-EOU model: +```bash +#!/bin/bash + +TRAIN_MANIFEST=/path/to/train_manifest.json +VAL_MANIFEST=/path/to/val_manifest.json +NOISE_MANIFEST=/path/to/noise_manifest.json + +PRETRAINED_NEMO=/path/to/pretrained_model.nemo +TOKENIZER_DIR=/path/to/tokenizer_dir + +BATCH_SIZE=16 +NUM_WORKERS=8 +LIMIT_TRAIN_BATCHES=1000 +VAL_CHECK_INTERVAL=1000 +MAX_STEPS=1000000 + +EXP_NAME=fastconformer_transducer_bpe_streaming_eou +SCRIPT=${NEMO_PATH}/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py +CONFIG_PATH=${NEMO_PATH}/examples/asr/conf/asr_eou +CONFIG_NAME=fastconformer_transducer_bpe_streaming + +CUDA_VISIBLE_DEVICES=0 python $SCRIPT \ + --config-path $CONFIG_PATH \ + --config-name $CONFIG_NAME \ + ++init_from_nemo_model=$PRETRAINED_NEMO \ + model.encoder.att_context_size="[70,1]" \ + model.tokenizer.dir=$TOKENIZER_DIR \ + model.train_ds.manifest_filepath=$TRAIN_MANIFEST \ + model.train_ds.augmentor.noise.manifest_path=$NOISE_MANIFEST \ + model.validation_ds.manifest_filepath=$VAL_MANIFEST \ + model.train_ds.batch_size=$BATCH_SIZE \ + model.train_ds.num_workers=$NUM_WORKERS \ + model.validation_ds.batch_size=$BATCH_SIZE \ + model.validation_ds.num_workers=$NUM_WORKERS \ + ~model.test_ds \ + trainer.limit_train_batches=$LIMIT_TRAIN_BATCHES \ + trainer.val_check_interval=$VAL_CHECK_INTERVAL \ + trainer.max_steps=$MAX_STEPS \ + exp_manager.name=$EXP_NAME +``` + +""" + +from dataclasses import is_dataclass +from typing import Optional + +import lightning.pytorch as pl +from omegaconf import DictConfig, OmegaConf, open_dict + +from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCBPEModel, EncDecRNNTBPEModel +from nemo.collections.asr.models.asr_eou_models import EncDecRNNTBPEEOUModel +from nemo.collections.asr.modules.rnnt import RNNTDecoder, RNNTJoint +from nemo.core import adapter_mixins +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo.utils.trainer_utils import resolve_trainer_cfg + + +def add_global_adapter_cfg(model, global_adapter_cfg): + # Convert to DictConfig from dict or Dataclass + if is_dataclass(global_adapter_cfg): + global_adapter_cfg = OmegaConf.structured(global_adapter_cfg) + + if not isinstance(global_adapter_cfg, DictConfig): + global_adapter_cfg = DictConfig(global_adapter_cfg) + + # Update the model.cfg with information about the new adapter global cfg + with open_dict(global_adapter_cfg), open_dict(model.cfg): + if 'adapters' not in model.cfg: + model.cfg.adapters = OmegaConf.create({}) + + # Add the global config for adapters to the model's internal config + model.cfg.adapters[model.adapter_global_cfg_key] = global_adapter_cfg + + # Update all adapter modules (that already exist) with this global adapter config + model.update_adapter_cfg(model.cfg.adapters) + + +def update_model_config_to_support_adapter(model_cfg): + with open_dict(model_cfg): + # Update encoder adapter compatible config + adapter_metadata = adapter_mixins.get_registered_adapter(model_cfg.encoder._target_) + if adapter_metadata is not None: + model_cfg.encoder._target_ = adapter_metadata.adapter_class_path + + +def setup_adapters(cfg: DictConfig, model: ASRModel): + # Setup adapters + with open_dict(cfg.model.adapter): + # Extract the name of the adapter (must be give for training) + adapter_name = cfg.model.adapter.pop("adapter_name") + adapter_type = cfg.model.adapter.pop("adapter_type") + adapter_module_name = cfg.model.adapter.pop("adapter_module_name", None) + + # 
Resolve the config of the specified `adapter_type` + if adapter_type not in cfg.model.adapter.keys(): + raise ValueError( + f"Adapter type ({adapter_type}) config could not be found. Adapter setup config - \n" + f"{OmegaConf.to_yaml(cfg.model.adapter)}" + ) + + adapter_type_cfg = cfg.model.adapter[adapter_type] + print(f"Found `{adapter_type}` config :\n" f"{OmegaConf.to_yaml(adapter_type_cfg)}") + + # Augment adapter name with module name, if not provided by user + if adapter_module_name is not None and ':' not in adapter_name: + adapter_name = f'{adapter_module_name}:{adapter_name}' + + # Extract the global adapter config, if provided + adapter_global_cfg = cfg.model.adapter.pop(model.adapter_global_cfg_key, None) + if adapter_global_cfg is not None: + add_global_adapter_cfg(model, adapter_global_cfg) + + model.add_adapter(adapter_name, cfg=adapter_type_cfg) + assert model.is_adapter_available() + + # Disable all other adapters, enable just the current adapter. + model.set_enabled_adapters(enabled=False) # disable all adapters prior to training + model.set_enabled_adapters(adapter_name, enabled=True) # enable just one adapter by name + + model.freeze() # freeze whole model by default + if not cfg.model.get("freeze_decoder", True): + model.decoder.unfreeze() + if hasattr(model, 'joint') and not cfg.model.get(f"freeze_joint", True): + model.joint.unfreeze() + + # Activate dropout() and other modules that depend on train mode. + model = model.train() + # Then, Unfreeze just the adapter weights that were enabled above (no part of encoder/decoder/joint/etc) + model.unfreeze_enabled_adapters() + return model + + +def get_pretrained_model_name(cfg: DictConfig) -> Optional[str]: + if hasattr(cfg, 'init_from_ptl_ckpt') and cfg.init_from_ptl_ckpt is not None: + raise NotImplementedError( + "Currently for simplicity of single script for all model types, we only support `init_from_nemo_model` and `init_from_pretrained_model`" + ) + nemo_model_path = cfg.get('init_from_nemo_model', None) + pretrained_name = cfg.get('init_from_pretrained_model', None) + if nemo_model_path is not None and pretrained_name is not None: + raise ValueError("Only pass `init_from_nemo_model` or `init_from_pretrained_model` but not both") + elif nemo_model_path is None and pretrained_name is None: + return None + + if nemo_model_path: + return nemo_model_path + if pretrained_name: + return pretrained_name + return None + + +def init_from_pretrained_nemo(model: EncDecRNNTBPEEOUModel, pretrained_model_path: str, cfg: DictConfig): + """ + Load the pretrained model from a .nemo file or remote checkpoint. If the pretrained model has exactly + the same vocabulary size as the current model, the whole model will be loaded directly. Otherwise, + the encoder and decoder weights will be loaded separately and the EOU/EOB classes will be handled separately. + """ + if pretrained_model_path.endswith('.nemo'): + pretrained_model = ASRModel.restore_from(restore_path=pretrained_model_path) # type: EncDecRNNTBPEModel + else: + pretrained_model = ASRModel.from_pretrained(pretrained_model_path) # type: EncDecRNNTBPEModel + + if not isinstance(pretrained_model, (EncDecRNNTBPEModel, EncDecHybridRNNTCTCBPEModel)): + raise TypeError( + f"Pretrained model {pretrained_model.__class__} is not EncDecRNNTBPEModel or EncDecHybridRNNTCTCBPEModel." + ) + + try: + model.load_state_dict(pretrained_model.state_dict(), strict=True) + logging.info( + f"Pretrained model from {pretrained_model_path} has exactly the same model structure, skip further loading." 
+ ) + return + except Exception: + logging.warning( + f"Pretrained model {pretrained_model_path} has different model structure, try loading weights separately and add EOU/EOB classes." + ) + + # Load encoder state dict into the model + model.encoder.load_state_dict(pretrained_model.encoder.state_dict(), strict=True) + logging.info(f"Encoder weights loaded from {pretrained_model_path}.") + + # Load decoder state dict into the model + decoder = model.decoder # type: RNNTDecoder + pretrained_decoder = pretrained_model.decoder # type: RNNTDecoder + if not isinstance(decoder, RNNTDecoder) or not isinstance(pretrained_decoder, RNNTDecoder): + raise TypeError( + f"Decoder {decoder.__class__} is not RNNTDecoder or pretrained decoder {pretrained_decoder.__class__} is not RNNTDecoder." + ) + + decoder.prediction["dec_rnn"].load_state_dict(pretrained_decoder.prediction["dec_rnn"].state_dict(), strict=True) + + decoder_embed_states = decoder.prediction["embed"].state_dict()['weight'] # shape: [num_classes+2, hid_dim] + pretrained_decoder_embed_states = pretrained_decoder.prediction["embed"].state_dict()[ + 'weight' + ] # shape: [num_classes, hid_dim] + if decoder_embed_states.shape[0] != pretrained_decoder_embed_states.shape[0] + 2: + raise ValueError( + f"Size mismatched between pretrained ({pretrained_decoder_embed_states.shape[0]}+2) and current model ({decoder_embed_states.shape[0]}), skip loading decoder embedding." + ) + + decoder_embed_states[:-3, :] = pretrained_decoder_embed_states[:-1, :] # everything except EOU, EOB and blank + decoder_embed_states[-1, :] = pretrained_decoder_embed_states[-1, :] # blank class + decoder.prediction["embed"].load_state_dict({"weight": decoder_embed_states}, strict=True) + logging.info(f"Decoder weights loaded from {pretrained_model_path}.") + + # Load joint network weights if new model's joint network has two more classes than the pretrained model + joint_network = model.joint # type: RNNTJoint + pretrained_joint_network = pretrained_model.joint # type: RNNTJoint + assert isinstance(joint_network, RNNTJoint), f"Joint network {joint_network.__class__} is not RNNTJoint." + assert isinstance( + pretrained_joint_network, RNNTJoint + ), f"Pretrained joint network {pretrained_joint_network.__class__} is not RNNTJoint." + joint_network.pred.load_state_dict(pretrained_joint_network.pred.state_dict(), strict=True) + joint_network.enc.load_state_dict(pretrained_joint_network.enc.state_dict(), strict=True) + + if joint_network.num_classes_with_blank != pretrained_joint_network.num_classes_with_blank + 2: + raise ValueError( + f"Size mismatched between pretrained ({pretrained_joint_network.num_classes_with_blank}+2) and current model ({joint_network.num_classes_with_blank}), skip loading joint network." 
+ ) + + # Load the joint network weights + pretrained_joint_state = pretrained_joint_network.joint_net.state_dict() + joint_state = joint_network.joint_net.state_dict() + pretrained_joint_clf_weight = pretrained_joint_state['2.weight'] # shape: [num_classes, hid_dim] + pretrained_joint_clf_bias = pretrained_joint_state['2.bias'] if '2.bias' in pretrained_joint_state else None + + token_init_method = cfg.model.get('token_init_method', 'constant') + # Copy the weights and biases from the pretrained model to the new model + # shape: [num_classes+2, hid_dim] + joint_state['2.weight'][:-3, :] = pretrained_joint_clf_weight[:-1, :] # everything except EOU, EOB and blank + joint_state['2.weight'][-1, :] = pretrained_joint_clf_weight[-1, :] # blank class + + value = None + if token_init_method == 'min': + # set the EOU and EOB class to the minimum value of the pretrained model + value = pretrained_joint_clf_weight.min(dim=0)[0] + elif token_init_method == 'max': + # set the EOU and EOB class to the maximum value of the pretrained model + value = pretrained_joint_clf_weight.max(dim=0)[0] + elif token_init_method == 'mean': + # set the EOU and EOB class to the mean value of the pretrained model + value = pretrained_joint_clf_weight.mean(dim=0) + elif token_init_method == 'constant': + value = cfg.model.get('token_init_weight_value', None) + elif token_init_method: + raise ValueError(f"Unknown token_init_method: {token_init_method}.") + + if value is not None: + joint_state['2.weight'][-2, :] = value # EOB class + joint_state['2.weight'][-3, :] = value # EOU class + + if pretrained_joint_clf_bias is not None and '2.bias' in joint_state: + joint_state['2.bias'][:-3] = pretrained_joint_clf_bias[:-1] # everything except EOU, EOB and blank + joint_state['2.bias'][-1] = pretrained_joint_clf_bias[-1] # blank class + value = None + if token_init_method == 'constant': + value = cfg.model.get('token_init_bias_value', None) + elif token_init_method == 'min': + # set the EOU and EOB class to the minimum value of the pretrained model + value = pretrained_joint_clf_bias.min() + elif token_init_method == 'max': + # set the EOU and EOB class to the maximum value of the pretrained model + value = pretrained_joint_clf_bias.max() + elif token_init_method == 'mean': + # set the EOU and EOB class to the mean value of the pretrained model + value = pretrained_joint_clf_bias.mean() + elif token_init_method: + raise ValueError(f"Unknown token_init_method: {token_init_method}.") + + if value is not None: + joint_state['2.bias'][-2] = value # EOB class + joint_state['2.bias'][-3] = value # EOU class + + # Load the joint network weights + joint_network.joint_net.load_state_dict(joint_state, strict=True) + logging.info(f"Joint network weights loaded from {pretrained_model_path}.") + + +@hydra_runner(config_path="../conf/asr_eou", config_name="fastconformer_transducer_bpe_streaming") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer)) + exp_manager(trainer, cfg.get("exp_manager", None)) + + if cfg.model.get("adapter", None) is not None: + update_model_config_to_support_adapter(cfg.model) + + asr_model = EncDecRNNTBPEEOUModel(cfg=cfg.model, trainer=trainer) + + init_from_model = get_pretrained_model_name(cfg) + if init_from_model: + init_from_pretrained_nemo(asr_model, init_from_model, cfg) + + if cfg.model.get("freeze_encoder", False): + logging.info("Freezing encoder weights.") + asr_model.encoder.freeze() + + if cfg.model.get("adapter", None) is 
not None: + asr_model = setup_adapters(cfg, asr_model) + + trainer.fit(asr_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py b/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py index 199e399ead11..34afa8309084 100644 --- a/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py +++ b/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py @@ -20,7 +20,7 @@ in NeMo. The resulting .nemo file will be a pure CTC or RNNT model, and can be used like any other .nemo model including in nemo2riva. -Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo -m ctc|rnnt +Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo -t ctc|rnnt """ diff --git a/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming.yaml b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming.yaml new file mode 100644 index 000000000000..6ec564245b8c --- /dev/null +++ b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming.yaml @@ -0,0 +1,318 @@ +# It contains the default values for training a cache-aware streaming FastConformer-Transducer ASR+EOU model, large size (~115M) with sub-word encoding. + +# You may find more detail: +# FastConformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer +# Cache-aware Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer +# FastConformer-Transducer's architecture config, along with the optimal batch size and precision: NeMo/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml + +name: "FastConformer-Transducer-BPE-Streaming-EOU" + +model: + token_init_method: "constant" # choices=['min', 'max', 'mean', 'constant'] + token_init_weight_value: null # only applicable when token_init_method='constant' + token_init_bias_value: -1000.0 # only applicable when token_init_method='constant' + + sample_rate: 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. + log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + + train_ds: + manifest_filepath: ??? 
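+ # Each entry in the manifest is expected to be a JSON line carrying EOU annotations in addition to the usual ASR fields,
+ # as described in speech_to_text_rnnt_eou_train.py, for example:
+ # {"audio_filepath": "/path/to/audio.wav", "text": "the text of the audio", "offset": 0.0, "duration": 3.0, "sou_time": 0.2, "eou_time": 1.5}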
+ tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + random_padding: + prob: 0.99 + min_post_pad_duration: 3.0 + min_pre_pad_duration: 0.0 + max_pad_duration: 6.0 # maximum duration of pre/post padding in seconds + max_total_duration: 40.0 # maximum total duration of the padded audio in seconds + pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal' + normal_mean: 0.5 # mean of normal distribution used when pad_distribution='normal' + normal_std: 2.0 # standard deviation of normal distribution used when pad_distribution='normal' + + augmentor: + white_noise: + prob: 0.9 + min_level: -90 + max_level: -46 + gain: + prob: 0.2 + min_gain_dbfs: -10.0 + max_gain_dbfs: 10.0 + noise: + prob: 0.9 + manifest_path: ??? + min_snr_db: 0 + max_snr_db: 20 + max_gain_db: 300.0 + + validation_ds: + manifest_filepath: ??? + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + + test_ds: + manifest_filepath: null + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + # We recommend to use vocab size of 1024 with SPE Unigram for most languages + tokenizer: + dir: ??? 
# path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "NA" # No normalization for mel-spectogram makes streaming easier + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: true + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large + # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one + # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. During the training, different context sizes would be used randomly with the distribution specified by att_context_probs. + # The first item in the list would be the default during test/validation/inference. 
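+ # (For reference, the single look-ahead used by default in this config, att_context_size=[70,1], corresponds to
+ # 1*8*0.01=0.08s of look-ahead by the formula above, given subsampling_factor=8 and window_stride=0.01.)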
+ # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25, 0.25] + att_context_size: [70, 1] # -1 means unlimited context + att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null + + xscaling: false # scales up the input embeddings by sqrt(d_model) + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + # Recommend to use causal convolutions as it would increase the effective right context and therefore the look-ahead significantly + conv_context_size: causal + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null # Currently only null is supported for export. + random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf + blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 1 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null # 'null' would set it automatically according to CPU/GPU device + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size. + # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss. + # Using small values here will preserve a lot of memory during training, but will make training slower as well. + # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1. + # However, to preserve memory, this ratio can be 1:8 or even 1:16. + # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow. + fuse_loss_wer: true + fused_batch_size: 4 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. 
+ + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + # config for InterCTC loss: https://arxiv.org/abs/2102.03216 + # specify loss weights and which layers to use for InterCTC + # e.g., to reproduce the paper results, set loss_weights: [0.3] + # and apply_at_layers: [8] (assuming 18 layers). Note that final + # layer loss coefficient is automatically adjusted (to 0.7 in above example) + interctc: + loss_weights: [] + apply_at_layers: [] + + loss: + loss_name: "default" + warprnnt_numba_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming + # You may set it to lower values like 1e-3 for models with larger right context + fastemit_lambda: 3e-2 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + optim: + name: adamw + lr: 5.0 # 1e-4 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing # NoamAnnealing CosineAnnealing + # scheduler config override + d_model: ${model.encoder.d_model} + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 100000 # computed at runtime if not set + val_check_interval: 1000 # an int for number of iterations + limit_train_batches: ${trainer.val_check_interval} + accelerator: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + use_distributed_sampler: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + filename: '${exp_manager.name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}' + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
+ resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_adapter.yaml b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_adapter.yaml new file mode 100644 index 000000000000..7cc70cf00378 --- /dev/null +++ b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_adapter.yaml @@ -0,0 +1,371 @@ +# It contains the default values for training a cache-aware streaming FastConformer-Transducer ASR+EOU model, large size (~115M) with sub-word encoding. + +# You may find more detail: +# FastConformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer +# Cache-aware Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer +# FastConformer-Transducer's architecture config, along with the optimal batch size and precision: NeMo/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml + +name: "FastConformer-Transducer-BPE-Streaming-EOU-adapter" + +model: + token_init_method: "constant" # choices=['min', 'max', 'mean', 'constant'] + token_init_weight_value: null # only applicable when token_init_method='constant' + token_init_bias_value: -1000.0 # only applicable when token_init_method='constant' + + sample_rate: 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. + log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + + adapter: + ### Config of the adapter training/eval script ### + adapter_name: "eou-adapter" # Name of the adapter, used by the script + adapter_type: "linear" # Type of the adapter. Corresponds to the subconfigs below. + adapter_module_name: null # Name of the adapter module. Combine multiple modules with '+' between module names. + adapter_state_dict_name: "adapters.pt" # If the individual adapters must be saved, a file name can be provided here. null disables this. + + ### Adapter Configs ### + # Linear / Houlsby Adapter (https://arxiv.org/abs/1902.00751) + linear: + # Config of the adapter module itself + _target_: nemo.collections.common.parts.adapter_modules.LinearAdapter + in_features: ${model.encoder.d_model} # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter. + dim: 32 # The hidden dimension of the adapter, as chosen by user, but small values are preferred to reduce param count. + activation: swish + norm_position: 'pre' # Can be `pre` or `post` + dropout: 0.0 # float, dropout for the adapter + + # Adapter strategy config + adapter_strategy: + _target_: nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy + stochastic_depth: 0.0 # float, setting to > 0 will enable stochastic depth for each adapter block. + l2_lambda: 0.0 # float, setting to > 0 will enable l2 norm auxiliary loss for each adapter's output. + + # Tiny-Attention Adapter (https://arxiv.org/abs/2211.01979) + # NOTE: Only supported for Attention based encoders. 
Make sure to pass `adapter_module_name` as "encoder" + tiny_attn: + # Config of the adapter module itself + # Defaults to Relative Positional Encoding MHA + # _target_ can instead be .MultiHeadAttentionAdapter if Conformer was originally using Absolute Positional Encoding. + _target_: nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.RelPositionMultiHeadAttentionAdapter + n_feat: ${model.encoder.d_model} # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter. + n_head: 1 # Number of heads for attention. + proj_dim: -1 # Can be `null` - to avoid projection, > 0 for explicit dim, or -1 to default to `n_head` + dropout_rate: 0.0 # float, dropout for the adapter + + # Adapter strategy config + adapter_strategy: + _target_: nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.MHAResidualAddAdapterStrategy + stochastic_depth: 0.0 # float, setting to > 0 will enable stochastic depth for each adapter block. + l2_lambda: 0.0 # float, setting to > 0 will enable l2 norm auxiliary loss for each adapter's output. + + # Optional global config available to all adapters at a global level. + # A global config is shared across every layer of the adapters, defining global properties rather + # than properties local to the adapter (as defined above). + # This can be useful in order to select *which type of adapter* is added, *what adapters to enable*, + # and further global operations that can decide dynamically how to support the requested adapter. + global_cfg: + check_encoder_adapter: True # ASR adapter key, determines whether to check if encoder adapter modules is supported + check_decoder_adapter: True # ASR adapter key, determines whether to check if decoder adapter modules is supported + check_joint_adapter: True # ASR adapter key, determines whether to check if joint adapter modules is supported + + freeze_decoder: ${model.adapter.global_cfg.check_decoder_adapter} + freeze_joint: ${model.adapter.global_cfg.check_joint_adapter} + + train_ds: + manifest_filepath: ??? + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + random_padding: + prob: 0.99 + min_pad_duration: 1.0 # minimum duration of pre/post padding in seconds + max_pad_duration: 10.0 # maximum duration of pre/post padding in seconds + max_total_duration: 40.0 # maximum total duration of the padded audio in seconds + pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal' + normal_mean: 0.5 # mean of normal distribution used when pad_distribution='normal' + normal_std: 2.0 # standard deviation of normal distribution used when pad_distribution='normal' + + augmentor: + white_noise: + prob: 0.9 + min_level: -90 + max_level: -46 + gain: + prob: 0.2 + min_gain_dbfs: -10.0 + max_gain_dbfs: 10.0 + noise: + prob: 0.9 + manifest_path: ??? + min_snr_db: 0 + max_snr_db: 20 + max_gain_db: 300.0 + + validation_ds: + manifest_filepath: ??? 
+ tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + + test_ds: + manifest_filepath: null + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + # We recommend to use vocab size of 1024 with SPE Unigram for most languages + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "NA" # No normalization for mel-spectogram makes streaming easier + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: true + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large + # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one + # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. 
During the training, different context sizes would be used randomly with the distribution specified by att_context_probs. + # The first item in the list would be the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25, 0.25] + att_context_size: [70, 1] # -1 means unlimited context + att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null + + xscaling: false # scales up the input embeddings by sqrt(d_model) + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + # Recommend to use causal convolutions as it would increase the effective right context and therefore the look-ahead significantly + conv_context_size: causal + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null # Currently only null is supported for export. + random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf + blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 1 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null # 'null' would set it automatically according to CPU/GPU device + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size. + # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss. + # Using small values here will preserve a lot of memory during training, but will make training slower as well. + # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1. + # However, to preserve memory, this ratio can be 1:8 or even 1:16. + # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow. + fuse_loss_wer: true + fused_batch_size: 4 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. 
+ + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + # config for InterCTC loss: https://arxiv.org/abs/2102.03216 + # specify loss weights and which layers to use for InterCTC + # e.g., to reproduce the paper results, set loss_weights: [0.3] + # and apply_at_layers: [8] (assuming 18 layers). Note that final + # layer loss coefficient is automatically adjusted (to 0.7 in above example) + interctc: + loss_weights: [] + apply_at_layers: [] + + loss: + loss_name: "default" + warprnnt_numba_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming + # You may set it to lower values like 1e-3 for models with larger right context + fastemit_lambda: 3e-2 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + optim: + name: adamw + lr: 5.0 # 1e-4 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing # NoamAnnealing CosineAnnealing + # scheduler config override + d_model: ${model.encoder.d_model} + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 100000 # computed at runtime if not set + val_check_interval: 1000 # an int for number of iterations + limit_train_batches: ${trainer.val_check_interval} + accelerator: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + use_distributed_sampler: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
+ resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 2cdc3a30b96d..abca4f374656 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -302,7 +302,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if cfg.decoder_type and cfg.decoder_type != 'rnnt': raise ValueError('RNNT model only support rnnt decoding!') - if cfg.decoder_type and hasattr(asr_model.encoder, 'set_default_att_context_size'): + if cfg.att_context_size and hasattr(asr_model.encoder, 'set_default_att_context_size'): asr_model.encoder.set_default_att_context_size(cfg.att_context_size) # Setup decoding strategy diff --git a/examples/asr/transcribe_speech_manifest_distributed.py b/examples/asr/transcribe_speech_manifest_distributed.py new file mode 100644 index 000000000000..9ecde83b4a8e --- /dev/null +++ b/examples/asr/transcribe_speech_manifest_distributed.py @@ -0,0 +1,301 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from math import ceil +from pathlib import Path +from typing import List + +from omegaconf import ListConfig +from tqdm import tqdm +from transcribe_speech import TranscriptionConfig as SingleTranscribeConfig +from transcribe_speech import main as single_transcribe_main + +from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest +from nemo.core.config import hydra_runner +from nemo.utils import logging + +""" +Transcribe audio manifests on distributed GPUs. Useful for transcription of moderate amounts of audio data. +This script also supports splitting the manifest into chunks and merging the results back together. +This script is a modified version of `transcribe_speech_distributed.py` that only takes manifest files as input. +It is useful for transcribing a large amount of audio data that does not fit into a single job. 
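+
+ Work is distributed without inter-process communication: each (node, GPU) pair takes a round-robin share of the
+ remaining manifests, and manifests whose output file already exists in `output_dir` are skipped, so an interrupted
+ job can simply be relaunched. When `split_size` > 0, each input manifest is first split into shards of `split_size`
+ samples, and the per-shard outputs are merged back into one output manifest per input once all shards are finished.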
+ +# Arguments + model_path: path to .nemo ASR checkpoint + pretrained_name: name of pretrained ASR model (from NGC registry) + dataset_manifest: path to dataset JSON manifest file (in NeMo formats), can be a comma-separated list of manifest files + or a directory containing manifest files + pattern: pattern to glob the manifest files if `dataset_manifest` is a directory + output_dir: directory to write the transcriptions + + compute_langs: Bool to request language ID information (if the model supports it) + timestamps: Bool to request greedy time stamp information (if the model supports it) by default None + + (Optionally: You can limit the type of timestamp computations using below overrides) + ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word, segment]) + rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word, segment]) + + output_filename: Output filename where the transcriptions will be written + batch_size: batch size during inference + presort_manifest: sorts the provided manifest by audio length for faster inference (default: True) + + cuda: Optional int to enable or disable execution of model on certain CUDA device. + allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available + amp: Bool to decide if Automatic Mixed Precision should be used during inference + audio_type: Str filetype of the audio. Supported = wav, flac, mp3 + + overwrite_transcripts: Bool which when set allows repeated transcriptions to overwrite previous results. + + ctc_decoding: Decoding sub-config for CTC. Refer to documentation for specific values. + rnnt_decoding: Decoding sub-config for RNNT. Refer to documentation for specific values. + + calculate_wer: Bool to decide whether to calculate wer/cer at end of this script + clean_groundtruth_text: Bool to clean groundtruth text + langid: Str used for convert_num_to_words during groundtruth cleaning + use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER) + + calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset. + +# Usage +ASR model can be specified by either "model_path" or "pretrained_name". +append_pred - optional. Allows you to add more than one prediction to an existing .json +pred_name_postfix - optional. The name you want to be written for the current model +Results are returned in a JSON manifest file. + +```bash +CUDA_VISIBLE_DEVICES=1 python transcribe_speech_distributed.py \ + model_path= \ + dataset_manifest="" \ + output_dir="" \ + output_filename="" \ + clean_groundtruth_text=True \ + langid='en' \ + batch_size=32 \ + timestamps=False \ + compute_langs=False \ + amp=True \ + append_pred=False \ + pred_name_postfix="" \ + split_size=10000 \ + num_nodes=1 \ + node_idx=0 \ + num_gpus_per_node=1 \ + gpu_idx=0 +``` + +If you use Slurm, you can use this params to configure the script: +```bash + gpu_idx=\$SLURM_LOCALID \ + num_gpus_per_node=\$SLURM_GPUS_ON_NODE \ + num_nodes=\$SLURM_JOB_NUM_NODES \ + node_idx=\$SLURM_NODEID +``` + +""" + + +@dataclass +class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ + + conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig) + + +@dataclass +class TranscriptionConfig(SingleTranscribeConfig): + """ + Transcription Configuration for audio to text transcription. 
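+     Extends the single-process `TranscriptionConfig` from `transcribe_speech.py` with distributed options
+     (`num_nodes`, `node_idx`, `num_gpus_per_node`, `gpu_idx`, `bind_gpu_to_cuda`) and manifest handling
+     options (`pattern`, `output_dir`, `split_size`).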
+ """ + + # General configs + pattern: str = "*.json" + output_dir: str = "transcribe_output/" + + # Distributed config + num_nodes: int = 1 # total number of nodes + node_idx: int = 0 # index of the current node + num_gpus_per_node: int = 1 # number of GPUs per node + gpu_idx: int = 0 # index of the current GPU + bind_gpu_to_cuda: bool = ( + False # If False, the script will just do .cuda() on the model, otherwise it will do .to(f"cuda:{gpu_idx}") + ) + + # handle long manifest + split_size: int = -1 # -1 means no split, otherwise split the manifest into chunks of this size + + +def get_unfinished_manifest(manifest_list: List[Path], output_dir: Path): + unfinished = [] + for manifest_file in manifest_list: + output_manifest_file = output_dir / manifest_file.name + if not output_manifest_file.exists(): + unfinished.append(manifest_file) + return sorted(unfinished) + + +def get_manifest_for_current_rank( + manifest_list: List[Path], gpu_id: int = 0, num_gpu: int = 1, node_idx: int = 0, num_node: int = 1 +): + node_manifest_list = [] + assert num_node > 0, f"num_node ({num_node}) must be greater than 0" + assert num_gpu > 0, f"num_gpu ({num_gpu}) must be greater than 0" + assert 0 <= gpu_id < num_gpu, f"gpu_id ({gpu_id}) must be in range [0, {num_gpu})" + assert 0 <= node_idx < num_node, f"node_idx ({node_idx}) must be in range [0, {num_node})" + for i, manifest_file in enumerate(manifest_list): + if (i + node_idx) % num_node == 0: + node_manifest_list.append(manifest_file) + + gpu_manifest_list = [] + for i, manifest_file in enumerate(node_manifest_list): + if (i + gpu_id) % num_gpu == 0: + gpu_manifest_list.append(manifest_file) + return gpu_manifest_list + + +def maybe_split_manifest(manifest_list: List[Path], cfg: TranscriptionConfig) -> List[Path]: + if cfg.split_size is None or cfg.split_size <= 0: + return manifest_list + + all_sharded_manifest_files = [] + sharded_manifest_dir = Path(cfg.output_dir) / "sharded_manifest_todo" + sharded_manifest_dir.mkdir(parents=True, exist_ok=True) + + sharded_manifest_done_dir = Path(cfg.output_dir) / "sharded_manifest_done" + sharded_manifest_done_dir.mkdir(parents=True, exist_ok=True) + cfg.output_dir = sharded_manifest_done_dir + + logging.info(f"Splitting {len(manifest_list)} manifest files by every {cfg.split_size} samples.") + for manifest_file in tqdm(manifest_list, total=len(manifest_list), desc="Splitting manifest files"): + manifest = read_manifest(manifest_file) + + num_chunks = ceil(len(manifest) / cfg.split_size) + for i in range(num_chunks): + chunk_manifest = manifest[i * cfg.split_size : (i + 1) * cfg.split_size] + sharded_manifest_file = sharded_manifest_dir / f"{manifest_file.stem}--tpart_{i}.json" + write_manifest(sharded_manifest_file, chunk_manifest) + all_sharded_manifest_files.append(sharded_manifest_file) + + return all_sharded_manifest_files + + +def maybe_merge_manifest(cfg: TranscriptionConfig): + if cfg.split_size is None or cfg.split_size <= 0: + return + + # only merge manifest on the first GPU of the first node + if not cfg.gpu_idx == 0 and cfg.node_idx == 0: + return + + sharded_manifest_dir = Path(cfg.output_dir) + sharded_manifests = list(sharded_manifest_dir.glob("*--tpart_*.json")) + if not sharded_manifests: + logging.info(f"No sharded manifest files found in {sharded_manifest_dir}") + return + + logging.info(f"Merging {len(sharded_manifests)} sharded manifest files.") + manifest_dict = defaultdict(list) + for sharded_manifest in sharded_manifests: + data_name = sharded_manifest.stem.split("--tpart_")[0] + 
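# Shard files are named "<original_stem>--tpart_<i>.json" by maybe_split_manifest(); group shards by the manifest they came from. +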
manifest_dict[data_name].append(sharded_manifest) + + output_dir = Path(cfg.output_dir).parent + for data_name, sharded_manifest_list in tqdm( + manifest_dict.items(), total=len(manifest_dict), desc="Merging manifest files" + ): + merged_manifest = [] + for sharded_manifest in sharded_manifest_list: + manifest = read_manifest(sharded_manifest) + merged_manifest.extend(manifest) + output_manifest = output_dir / f"{data_name}.json" + write_manifest(output_manifest, merged_manifest) + logging.info(f"Merged manifest files saved to {output_dir}") + + +@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) +def run_distributed_transcribe(cfg: TranscriptionConfig): + + logging.info(f"Running distributed transcription with config: {cfg}") + + if cfg.dataset_manifest is None: + raise ValueError("`dataset_manifest` is required") + + # load the manifest + if isinstance(cfg.dataset_manifest, str) and "," in cfg.dataset_manifest: + manifest_list = cfg.dataset_manifest.split(",") + elif isinstance(cfg.dataset_manifest, (ListConfig, list)): + manifest_list = cfg.dataset_manifest + else: + input_manifest = Path(cfg.dataset_manifest) + if input_manifest.is_dir(): + manifest_list = list(input_manifest.glob(cfg.pattern)) + elif input_manifest.is_file(): + manifest_list = [input_manifest] + else: + raise ValueError(f"Invalid manifest file or directory: {input_manifest}") + + if not manifest_list: + raise ValueError(f"No manifest files found matching pattern: {cfg.pattern} in {input_manifest}") + + manifest_list = maybe_split_manifest(manifest_list, cfg) + original_manifest_list = list(manifest_list) + logging.info(f"Found {len(manifest_list)} manifest files.") + + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + unfinished_manifest = get_unfinished_manifest(manifest_list, output_dir=output_dir) + if not unfinished_manifest: + maybe_merge_manifest(cfg) + logging.info("All manifest files have been processed. Exiting.") + return + logging.info(f"Found {len(unfinished_manifest)} unfinished manifest files.") + + manifest_list = get_manifest_for_current_rank( + unfinished_manifest, + gpu_id=cfg.gpu_idx, + num_gpu=cfg.num_gpus_per_node, + node_idx=cfg.node_idx, + num_node=cfg.num_nodes, + ) + if not manifest_list: + logging.info(f"No manifest files found for GPU {cfg.gpu_idx} on node {cfg.node_idx}. Exiting.") + return + + logging.info(f"Processing {len(manifest_list)} manifest files with GPU {cfg.gpu_idx} on node {cfg.node_idx}.") + + cfg.cuda = cfg.gpu_idx if cfg.bind_gpu_to_cuda else None + for manifest_file in tqdm(manifest_list): + logging.info(f"Processing {manifest_file}...") + output_filename = output_dir / Path(manifest_file).name + curr_cfg = deepcopy(cfg) + curr_cfg.dataset_manifest = str(manifest_file) + curr_cfg.output_filename = str(output_filename) + + single_transcribe_main(curr_cfg) + + # check if all manifest files have been processed + unfinished_manifest = get_unfinished_manifest(original_manifest_list, output_dir=output_dir) + if not unfinished_manifest: + maybe_merge_manifest(cfg) + logging.info("All manifest files have been processed. 
Exiting.") + return + + +if __name__ == '__main__': + run_distributed_transcribe() # noqa pylint: disable=no-value-for-parameter diff --git a/nemo/collections/asr/data/audio_to_eou_label_lhotse.py b/nemo/collections/asr/data/audio_to_eou_label_lhotse.py new file mode 100644 index 000000000000..2283325d9cd4 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_eou_label_lhotse.py @@ -0,0 +1,626 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Dict, List, Optional + +import numpy as np +import torch.utils.data +from lhotse.cut import Cut, CutSet, MixedCut +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors +from omegaconf import DictConfig, OmegaConf + +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.utils import logging + +NON_SPEECH_LABEL = 0 +SPEECH_LABEL = 1 +EOU_LABEL = 2 +EOB_LABEL = 3 +EOU_STRING = '' +EOB_STRING = '' + + +EOU_LENGTH_PERTURBATION = ['speed', 'time_stretch'] +EOU_PROHIBITED_AUGMENTATIONS = ['random_segment'] + + +def first_supervised_cut(maybe_mixed_cut): + """ + Get the first supervised cut from a mixed cut, skip the noise cut in case the noise cut has supervision. + Args: + maybe_mixed_cut: Cut or MixedCut + Returns: + Cut: The first supervised cut from the mixed cut + """ + if isinstance(maybe_mixed_cut, MixedCut): + return [ + t.cut + for t in maybe_mixed_cut.tracks + if len(t.cut.supervisions) > 0 and not t.cut.custom.get("is_mixed_noise", False) + ][0] + return maybe_mixed_cut + + +@dataclass +class AudioToTextEOUBatch: + """ + Data class for ASR-EOU batch. 
+ """ + + sample_ids: List | None = None + audio_filepaths: List | None = None + audio_signal: torch.Tensor | None = None + audio_lengths: torch.Tensor | None = None + text_tokens: torch.Tensor | None = None + text_token_lengths: torch.Tensor | None = None + eou_targets: torch.Tensor | None = None + eou_target_lengths: torch.Tensor | None = None + + +@dataclass +class RandomPaddingConfig: + prob: float = 0.9 # probability of applying padding + min_pad_duration: float = 0.0 # minimum duration of pre/post padding in seconds + max_pad_duration: float = 5.0 # maximum duration of pre/post padding in seconds + max_total_duration: float = 40.0 # maximum total duration of the padded audio in seconds + min_pre_pad_duration: float = 0.0 # minimum duration of pre-padding in seconds + min_post_pad_duration: float = 2.0 # minimum duration of post-padding in seconds + pad_distribution: str = 'uniform' # distribution of padding duration, 'uniform' or 'normal' or 'constant' + normal_mean: float = 0.5 # mean of normal distribution for padding duration + normal_std: float = 2.0 # standard deviation of normal distribution for padding duration + pre_pad_duration: float = 0.2 # amount of left-padding when pad_distribution='constant' + post_pad_duration: float = 3.0 # amount of right-padding when pad_distribution='constant' + + +class LhotseSpeechToTextBpeEOUDataset(torch.utils.data.Dataset): + """ + This dataset processes the audio data and the corresponding text data to generate the ASR labels, + along with EOU labels for each frame. The audios used in this dataset should only contain speech with + NO precedding or following silence. The dataset also randomly pads non-speech frames before and after + the audio signal for training EOU prediction task. + + To generate EOU labels, the last frame of utterance will be marked as "end of utterance" (labeled as `2`), + while if it's a backchannel utterance it'll be marked asd "end of backchannel" (labeled as `3`). + The rest of the speech frames will be marked as "speech" (labeled as `1`). + The padded non-speech signals will be marked as "non-speech" (labeled as 0). 
+ + Args: + cfg: DictConfig object container following keys, usually taken from your `model.train_ds` + or `model.validation_ds` config: + ``` + sample_rate: # int, Sample rate of the audio signal + window_stride: # float, Window stride for audio encoder + subsampling_factor: # Subsampling factor for audio encoder + random_padding: # Random padding configuration + prob: 0.9 # probability of applying padding + min_pad_duration: 0.5 # minimum duration of pre/post padding in seconds + max_pad_duration: 2.0 # maximum duration of pre/post padding in seconds + max_total_duration: 30.0 # maximum total duration of the padded audio in seconds + pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal' or 'constant' + normal_mean: 0.5 # mean of normal distribution for padding duration + normal_std: 2.0 # standard deviation of normal distribution for padding duration + pre_pad_duration: 0.2 # amount of left-padding when pad_distribution='constant' + post_pad_duration: 3.0 # amount of right-padding when pad_distribution='constant' + ``` + + Returns: + audio: torch.Tensor of audio signal + audio_lens: torch.Tensor of audio signal length + text_tokens: torch.Tensor of text text_tokens + text_token_lens: torch.Tensor of text token length + eou_targets (optional): torch.Tensor of EOU labels + eou_target_lens (optional): torch.Tensor of EOU label length + + The input manifest should be a jsonl file where each line is a python dictionary. + Example manifest sample: + { + "audio_filepath": "/path/to/audio.wav", + "offset": 0.0, + "duration": 6.0, + "sou_time": [0.3, 4.0], + "eou_time": [1.3, 4.5], + "utterances": ["Tell me a joke", "Ah-ha"], + "is_backchannel": [False, True], + } + + Padding logic: + 0. Don't pad when `random_padding` is None or during validation/test + 1. randomly draw a probability to decide whether to apply padding + 2. if not padding or audio duration is longer than the maximum duration, + 1) return the original audio and EOU labels + 3. 
if apply padding, + 1) get the max padding duration based on the maximum total duration and the audio duration + 2) randomly draw a total padding duration based on the given distribution + 3) randomly split the total padding duration into pre-padding and post-padding + 4) randomly generate the non-speech signal (audio signal=0) for pre-padding and post-padding + 5) concatenate the pre-padding, audio, and post-padding to get the padded audio signal + 6) update the EOU labels accordingly + + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Define the output types of the dataset.""" + return { + 'audio': NeuralType(('B', 'T'), AudioSignal()), + 'audio_lens': NeuralType(tuple('B'), LengthsType()), + 'eou_targets': NeuralType(('B', 'T'), LabelsType()), + 'eou_target_lens': NeuralType(tuple('B'), LengthsType()), + 'text_tokens': NeuralType(tuple('B', 'T'), LengthsType(), optional=True), + 'text_token_lens': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__(self, cfg: DictConfig, tokenizer: TokenizerSpec, return_cuts: bool = False): + super().__init__() + self.cfg = cfg + self.return_cuts = return_cuts + self.eou_string = self.cfg.get('eou_string', EOU_STRING) + self.eob_string = self.cfg.get('eob_string', EOB_STRING) + if cfg.get('check_tokenizer', True): + self._check_special_tokens(tokenizer) + + self.tokenizer = TokenizerWrapper(tokenizer) + self.load_audio = AudioSamples(fault_tolerant=True) + self.sample_rate = self.cfg.get('sample_rate', 16000) + self.window_stride = self.cfg.get('window_stride', 0.01) + self.num_sample_per_mel_frame = int( + self.window_stride * self.sample_rate + ) # 160 samples for every 1ms by default + self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) + self.add_sep_before_eou = self.cfg.get('add_sep_before_eou', False) + self.add_eou_to_text = self.cfg.get('add_eou_to_text', True) + self.pad_eou_label_secs = self.cfg.get('pad_eou_label_secs', 0.0) + self.padding_cfg = self.cfg.get('random_padding', None) + if self.padding_cfg is not None: + self.padding_cfg = OmegaConf.to_container(self.padding_cfg, resolve=True) + self.padding_cfg = RandomPaddingConfig(**self.padding_cfg) + self.ignore_eob_label = self.cfg.get('ignore_eob_label', False) + self.augmentor = None + self.len_augmentor = None + self.skip_augment = self.cfg.get("skip_augment", False) + logging.info(f"EOU dataset with skip augmentations = {self.skip_augment}") + if self.cfg.get('augmentor', None) is not None: + augmentor = {} + len_augmentor = {} + aug_cfg = OmegaConf.to_container(self.cfg.augmentor, resolve=True) + for k, v in aug_cfg.items(): + if k in EOU_PROHIBITED_AUGMENTATIONS: + logging.warning(f"EOU dataset does not support {k} augmentation, skipping.") + continue + if k in EOU_LENGTH_PERTURBATION: + len_augmentor[k] = v + else: + augmentor[k] = v + + if len(augmentor) > 0: + logging.info(f"EOU dataset will apply augmentations: {augmentor}") + self.augmentor = process_augmentations(augmentor) + if len(len_augmentor) > 0: + logging.info(f"EOU dataset will apply length augmentations: {len_augmentor}") + self.len_augmentor = process_augmentations(len_augmentor) + + def _check_special_tokens(self, tokenizer: TokenizerSpec): + """ + Check if the special tokens are in the tokenizer vocab. 
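+        Assumes the EOU and EOB strings were appended as the last two entries of the tokenizer vocabulary
+        (see the error message below for the helper script that adds them).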
+ """ + special_tokens = set([self.eou_string, self.eob_string]) + vocab_size = tokenizer.vocab_size + special_tokens_in_vocab = set([tokenizer.ids_to_text(vocab_size - 1), tokenizer.ids_to_text(vocab_size - 2)]) + if special_tokens != special_tokens_in_vocab: + raise ValueError( + f"Input special tokens {special_tokens} don't match with the tokenizer vocab {special_tokens_in_vocab}. " + f"Please add them to tokenizer or change input `eou_string` and/or `eob_string` accordingly. " + "Special tokens should be added as the last two tokens in the new tokenizer. " + "Please refer to scripts/asr_end_of_utterance/tokenizers/add_special_tokens_to_sentencepiece.py for details." + ) + + def _simple_getitem(self, cuts: CutSet) -> AudioToTextEOUBatch: + """ + Simple getitem function when skipping all augmentations. + """ + audio, audio_lens, cuts = self.load_audio(cuts) + if self.return_cuts: + return audio, audio_lens, cuts + + eou_targets = [] + text_tokens = [] + sample_ids = [] + audio_filepaths = [] + for i in range(len(cuts)): + c = cuts[i] + if isinstance(c, MixedCut): + c = first_supervised_cut(c) + + sample_ids.append(c.id) + audio_filepaths.append(c.recording.sources[0].source) + # Get EOU labels and text tokens + eou_targets_i = self._get_frame_labels(c, audio_lens[i]) + text_tokens_i = self._get_text_tokens(c) + eou_targets.append(eou_targets_i) + text_tokens.append(text_tokens_i) + + eou_target_lens = torch.tensor([t.size(0) for t in eou_targets], dtype=torch.long) + eou_targets = collate_vectors(eou_targets, padding_value=0) + text_token_lens = torch.tensor([t.size(0) for t in text_tokens], dtype=torch.long) + text_tokens = collate_vectors(text_tokens, padding_value=0) + return AudioToTextEOUBatch( + sample_ids=sample_ids, + audio_filepaths=audio_filepaths, + audio_signal=audio, + audio_lengths=audio_lens, + text_tokens=text_tokens, + text_token_lengths=text_token_lens, + eou_targets=eou_targets, + eou_target_lengths=eou_target_lens, + ) + + def __getitem__(self, cuts: CutSet) -> AudioToTextEOUBatch: + if self.skip_augment: + return self._simple_getitem(cuts) + + audio, audio_lens, cuts = self.load_audio(cuts) + audio_signals = [] + audio_lengths = [] + eou_targets = [] + text_tokens = [] + sample_ids = [] + audio_filepaths = [] + + for i in range(len(cuts)): + c = cuts[i] + if isinstance(c, MixedCut): + c = c.first_non_padding_cut + + sample_ids.append(c.id) + audio_filepaths.append(c.recording.sources[0].source) + + audio_i = audio[i] + audio_len_i = audio_lens[i] + + # Maybe apply speed perturbation, this has to be done before getting the EOU labels + audio_i, audio_len_i = self._maybe_augment_length(audio_i, audio_len_i) + + # Get EOU labels and text tokens + eou_targets_i = self._get_frame_labels(c, audio_len_i) + text_tokens_i = self._get_text_tokens(c) + + # Maybe apply random padding to both sides of the audio + audio_i, audio_len_i, eou_targets_i = self._random_pad_audio(audio_i, audio_len_i, eou_targets_i) + + # Maybe apply augmentations to the audio signal after padding + audio_i, audio_len_i = self._maybe_augment_audio(audio_i, audio_len_i) + + # Append the processed audio, EOU labels, and text tokens to the lists + audio_signals.append(audio_i) + audio_lengths.append(audio_len_i) + eou_targets.append(eou_targets_i) + text_tokens.append(text_tokens_i) + + audio_signals = collate_vectors(audio_signals, padding_value=0) + audio_lengths = torch.tensor(audio_lengths, dtype=torch.long) + eou_target_lens = torch.tensor([t.size(0) for t in eou_targets], dtype=torch.long) + 
eou_targets = collate_vectors(eou_targets, padding_value=0) + text_token_lens = torch.tensor([t.size(0) for t in text_tokens], dtype=torch.long) + text_tokens = collate_vectors(text_tokens, padding_value=0) + + if self.return_cuts: + return audio_signals, audio_lengths, cuts + + return AudioToTextEOUBatch( + sample_ids=sample_ids, + audio_filepaths=audio_filepaths, + audio_signal=audio_signals, + audio_lengths=audio_lengths, + text_tokens=text_tokens, + text_token_lengths=text_token_lens, + eou_targets=eou_targets, + eou_target_lengths=eou_target_lens, + ) + + def _audio_len_to_frame_len(self, num_samples: int): + """ + Convert the raw audio length to the number of frames after audio encoder. + + self.num_sample_per_mel_frame = int( + self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) + ) # 160 samples for every 1ms by default + self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) + """ + mel_frame_count = math.ceil((num_samples + 1) / self.num_sample_per_mel_frame) + hidden_length = math.ceil(mel_frame_count / self.num_mel_frame_per_target_frame) + return hidden_length + + def _repeat_eou_labels(self, eou_targets: torch.Tensor) -> torch.Tensor: + """ + Repeat EOU labels according to self.pad_eou_label_secs + Args: + eou_targets: torch.Tensor of EOU labels, shape [T] + Returns: + eou_targets: torch.Tensor of padded EOU labels, shape [T] + """ + if not self.pad_eou_label_secs or self.pad_eou_label_secs <= 0: + return eou_targets + + eou_len = self._audio_len_to_frame_len(int(self.pad_eou_label_secs * self.sample_rate)) + + i = 0 + while i < eou_targets.size(0): + if eou_targets[i] == EOU_LABEL or eou_targets[i] == EOB_LABEL: + # repeat the label for the next eou_len samples + start = i + end = min(i + eou_len, eou_targets.size(0)) + j = start + 1 + while j < end: + if eou_targets[j] != NON_SPEECH_LABEL: + # do not overwrite the label if it's not non-speech + break + j += 1 + end = min(j, end) + # fill the non-speech label with the current EOU/EOB label + eou_targets[start:end] = eou_targets[i] + i = end + else: + i += 1 + return eou_targets + + def _get_frame_labels(self, cut: Cut, num_samples: int): + """ + Get the frame-level EOU labels for a single audio segment. 
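+        With the default 16 kHz sample rate, 0.01 s window stride, and subsampling factor 8,
+        each output frame covers roughly 80 ms of audio.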
+ Args: + cut: Cut object + num_samples: int, the number of samples in the audio segment + Returns: + eou_targets: torch.Tensor of EOU labels, shape [T] + """ + hidden_length = self._audio_len_to_frame_len(num_samples) + if not "sou_time" in cut.custom or not "eou_time" in cut.custom: + # assume only single speech segment + text = cut.supervisions[0].text + if not text: + # skip empty utterances + return torch.zeros(hidden_length).long() + eou_targets = torch.ones(hidden_length).long() # speech label + eou_targets[-1] = EOU_LABEL # by default it's end of utterance + if cut.has_custom("is_backchannel") and cut.custom["is_backchannel"] and not self.ignore_eob_label: + eou_targets[-1] = EOB_LABEL # end of backchannel + return eou_targets + + sou_time = cut.custom["sou_time"] + eou_time = cut.custom["eou_time"] + if not isinstance(sou_time, list): + sou_time = [sou_time] + if not isinstance(eou_time, list): + eou_time = [eou_time] + + assert len(sou_time) == len( + eou_time + ), f"Number of SOU time and EOU time do not match: SOU ({sou_time}) vs EOU ({eou_time})" + + if cut.has_custom("is_backchannel"): + is_backchannel = cut.custom["is_backchannel"] + if not isinstance(is_backchannel, list): + is_backchannel = [is_backchannel] + assert len(sou_time) == len( + is_backchannel + ), f"Number of SOU and backchannel do not match: SOU ({len(sou_time)}) vs backchannel ({len(is_backchannel)})" + else: + is_backchannel = [False] * len(sou_time) + + eou_targets = torch.zeros(hidden_length).long() + for i in range(len(sou_time)): + if sou_time[i] is None or eou_time[i] is None or sou_time[i] < 0 or eou_time[i] < 0: + # skip empty utterances + continue + sou_idx = self._audio_len_to_frame_len(int((sou_time[i] - cut.start) * self.sample_rate)) + seg_len_in_secs = eou_time[i] - sou_time[i] + seg_len = self._audio_len_to_frame_len(int(seg_len_in_secs * self.sample_rate)) + eou_targets[sou_idx : sou_idx + seg_len] = SPEECH_LABEL + last_idx = min(sou_idx + seg_len - 1, hidden_length - 1) + if is_backchannel[i] and not self.ignore_eob_label: + eou_targets[last_idx] = EOB_LABEL # end of backchannel + else: + eou_targets[last_idx] = EOU_LABEL # end of utterance + + return eou_targets + + def _get_text_tokens(self, cut: Cut): + """ + Add EOU labels to the text and get the text tokens for a single audio segment. 
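+        For example (illustrative), utterances=["Tell me a joke", "Ah-ha"] with is_backchannel=[False, True]
+        produces roughly "Tell me a joke" + eou_string + " " + "Ah-ha" + eob_string, which is then tokenized.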
+ Args: + cut: Cut object + Returns: + text_tokens: torch.Tensor of text tokens, shape [T] + """ + if not cut.has_custom("sou_time") or not cut.has_custom("eou_time") or not cut.has_custom("utterances"): + # assume only single speech segment + utterances = [cut.supervisions[0].text] + else: + utterances = cut.custom["utterances"] + + if not isinstance(utterances, list): + utterances = [utterances] + + if cut.has_custom("is_backchannel"): + is_backchannel = cut.custom["is_backchannel"] + if not isinstance(is_backchannel, list): + is_backchannel = [is_backchannel] + assert len(utterances) == len( + is_backchannel + ), f"Number of utterances and backchannel do not match: utterance ({len(utterances)}) vs backchannel ({len(is_backchannel)})" + else: + is_backchannel = [False] * len(utterances) + + total_text = "" + for i, text in enumerate(utterances): + if not text: + # skip empty utterances + continue + if self.add_eou_to_text: + eou_string = self.eob_string if is_backchannel[i] and not self.ignore_eob_label else self.eou_string + if self.add_sep_before_eou: + eou_string = " " + eou_string + else: + eou_string = "" + total_text += text + eou_string + " " + total_text = total_text.strip() + return torch.as_tensor(self.tokenizer(total_text)) + + def _random_pad_audio(self, audio: torch.Tensor, audio_len: torch.Tensor, eou_targets: torch.Tensor): + """ + Randomly pad the audio signal with non-speech signal before and after the audio signal. + Args: + audio: torch.Tensor of a single audio signal, shape [T] + audio_len: torch.Tensor of audio signal length, shape [1] + eou_targets: torch.Tensor of EOU labels, shape [T] + Returns: + padded_audio: torch.Tensor of padded audio signal, shape [T+padding] + padded_audio_len: torch.Tensor of padded audio signal length, shape [1] + padded_eou_targets: torch.Tensor of padded EOU labels, shape [T+padding] + padded_eou_targets_len: torch.Tensor of padded EOU label length, shape [1] + """ + p = np.random.rand() + if self.padding_cfg is None or p > self.padding_cfg.prob: + # don't apply padding + eou_targets = self._repeat_eou_labels(eou_targets) + return audio, audio_len, eou_targets + + duration = audio_len.item() / self.cfg.sample_rate + # if already longer than the maximum duration, return the original audio + if duration >= self.padding_cfg.max_total_duration: + return audio, audio_len, eou_targets + + # apply padding + audio = audio[:audio_len] + + self.padding_cfg.min_pre_pad_duration = max( + self.padding_cfg.min_pre_pad_duration, self.padding_cfg.min_pad_duration + ) + self.padding_cfg.min_post_pad_duration = max( + self.padding_cfg.min_post_pad_duration, self.padding_cfg.min_pad_duration + ) + + max_padding_duration = max(0, self.padding_cfg.max_total_duration - duration) + if max_padding_duration <= self.padding_cfg.min_pre_pad_duration + self.padding_cfg.min_post_pad_duration: + min_padding_duration = 0 + else: + min_padding_duration = self.padding_cfg.min_pre_pad_duration + self.padding_cfg.min_post_pad_duration + + pre_padding_duration = None + post_padding_duration = None + + if self.padding_cfg.pad_distribution == 'uniform': + total_padding_duration = np.random.uniform(min_padding_duration, max_padding_duration) + elif self.padding_cfg.pad_distribution == 'normal': + total_padding_duration = np.random.normal(self.padding_cfg.normal_mean, self.padding_cfg.normal_std) + total_padding_duration = max(min_padding_duration, min(max_padding_duration, total_padding_duration)) + elif self.padding_cfg.pad_distribution == 'constant': + pass + else: + raise 
ValueError(f"Unknown padding distribution: {self.padding_cfg.pad_distribution}") + + if self.padding_cfg.pad_distribution == 'constant': + pre_padding_duration = self.padding_cfg.pre_pad_duration + post_padding_duration = self.padding_cfg.post_pad_duration + elif min_padding_duration == 0: + pre_padding_duration = total_padding_duration / 2 + post_padding_duration = total_padding_duration / 2 + else: + post_padding_duration = np.random.uniform( + self.padding_cfg.min_post_pad_duration, total_padding_duration - self.padding_cfg.min_pre_pad_duration + ) + pre_padding_duration = total_padding_duration - post_padding_duration + + if self.padding_cfg.max_pad_duration is not None: + pre_padding_duration = min(pre_padding_duration, self.padding_cfg.max_pad_duration) + post_padding_duration = min(post_padding_duration, self.padding_cfg.max_pad_duration) + + pre_padding_len = math.ceil(pre_padding_duration * self.cfg.sample_rate) + post_padding_len = math.ceil(post_padding_duration * self.cfg.sample_rate) + + # pad the audio signal + pre_padding = torch.zeros(pre_padding_len, dtype=audio.dtype) + post_padding = torch.zeros(post_padding_len, dtype=audio.dtype) + padded_audio = torch.cat((pre_padding, audio, post_padding), dim=0) + padded_audio_len = audio_len + pre_padding_len + post_padding_len + + # pad the EOU labels + pre_padding_eou_len = self._audio_len_to_frame_len(pre_padding_len) + post_padding_eou_len = self._audio_len_to_frame_len(post_padding_len) + pre_padding_eou = torch.zeros(pre_padding_eou_len, dtype=eou_targets.dtype) + post_padding_eou = torch.zeros(post_padding_eou_len, dtype=eou_targets.dtype) + padded_eou_targets = torch.cat((pre_padding_eou, eou_targets, post_padding_eou), dim=0) + + padded_eou_targets = self._repeat_eou_labels(padded_eou_targets) + return padded_audio, padded_audio_len, padded_eou_targets + + def _maybe_augment_audio(self, audio: torch.Tensor, audio_len: torch.Tensor): + """ + Apply augmentation to the audio signal if augmentor is provided. + Args: + audio: torch.Tensor of a single audio signal, shape [T] + audio_len: torch.Tensor of audio signal length, shape [1] + Returns: + augmented_audio: torch.Tensor of augmented audio signal, shape [T] + augmented_audio_len: torch.Tensor of augmented audio signal length, shape [1] + """ + if self.augmentor is None: + return audio, audio_len + + # Cast to AudioSegment + audio_segment = AudioSegment( + samples=audio[:audio_len].numpy(), + sample_rate=self.sample_rate, + offset=0, + duration=audio_len.item() / self.sample_rate, + ) + # Apply augmentation + self.augmentor.perturb(audio_segment) + audio = torch.from_numpy(audio_segment.samples).float() + audio_len = audio.size(0) + + return audio, audio_len + + def _maybe_augment_length(self, audio: torch.Tensor, audio_len: torch.Tensor): + """ + Apply length augmentation (e.g., speed perturb) to the audio signal if augmentor is provided. 
+ Args: + audio: torch.Tensor of a single audio signal, shape [T] + audio_len: torch.Tensor of audio signal length, shape [1] + Returns: + augmented_audio: torch.Tensor of augmented audio signal, shape [T] + augmented_audio_len: torch.Tensor of augmented audio signal length, shape [1] + """ + if self.len_augmentor is None: + return audio, audio_len + + # Cast to AudioSegment + audio_segment = AudioSegment( + samples=audio[:audio_len].numpy(), + sample_rate=self.sample_rate, + offset=0, + duration=audio_len.item() / self.sample_rate, + ) + # Apply augmentation + self.len_augmentor.perturb(audio_segment) + audio = torch.from_numpy(audio_segment.samples).float() + audio_len = audio.size(0) + + return audio, audio_len diff --git a/nemo/collections/asr/losses/ssl_losses/mlm.py b/nemo/collections/asr/losses/ssl_losses/mlm.py index 424374869c3d..4ed6f580bbb2 100644 --- a/nemo/collections/asr/losses/ssl_losses/mlm.py +++ b/nemo/collections/asr/losses/ssl_losses/mlm.py @@ -65,11 +65,14 @@ def forward( if masks is None: masks = spec_masks - # B,D,T -> B,T,D - masks = masks.transpose(1, 2) + if masks is None: + masks = torch.ones_like(decoder_outputs, dtype=torch.bool) + else: + # B,D,T -> B,T,D + masks = masks.transpose(1, 2) - masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1) - masks = masks.mean(-1) > self.mask_threshold + masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1) + masks = masks.mean(-1) > self.mask_threshold out_masked_only = decoder_outputs[masks] targets = F.pad(targets, (0, masks.shape[-1] - targets.shape[-1])) diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index 719af4adcd3b..4e9c6021b604 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from copy import deepcopy from typing import List, Optional, Tuple, Union import editdistance @@ -255,7 +256,7 @@ def __init__( batch_dim_index=0, dist_sync_on_step=False, sync_on_compute=True, - **kwargs, + return_hypotheses=False, ): super().__init__(dist_sync_on_step=dist_sync_on_step, sync_on_compute=sync_on_compute) @@ -264,30 +265,33 @@ def __init__( self.log_prediction = log_prediction self.fold_consecutive = fold_consecutive self.batch_dim_index = batch_dim_index + self.return_hypotheses = return_hypotheses self.decode = None if isinstance(self.decoding, AbstractRNNTDecoding): self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids: self.decoding.rnnt_decoder_predictions_tensor( - encoder_output=predictions, encoded_lengths=predictions_lengths + encoder_output=predictions, encoded_lengths=predictions_lengths, return_hypotheses=return_hypotheses ) elif isinstance(self.decoding, AbstractCTCDecoding): self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids: self.decoding.ctc_decoder_predictions_tensor( decoder_outputs=predictions, decoder_lengths=predictions_lengths, fold_consecutive=self.fold_consecutive, + return_hypotheses=return_hypotheses, ) elif isinstance(self.decoding, AbstractMultiTaskDecoding): self.decode = lambda predictions, prediction_lengths, predictions_mask, input_ids: self.decoding.decode_predictions_tensor( encoder_hidden_states=predictions, encoder_input_mask=predictions_mask, decoder_input_ids=input_ids, - return_hypotheses=False, + return_hypotheses=return_hypotheses, ) else: raise TypeError(f"WER metric does not support decoding of type {type(self.decoding)}") self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.hypotheses = None def update( self, @@ -352,8 +356,22 @@ def update( self.scores = torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype) self.words = torch.tensor(words, device=self.words.device, dtype=self.words.dtype) + self.hypotheses = hypotheses + return None def compute(self): scores = self.scores.detach().float() words = self.words.detach().float() return scores / words, scores, words + + def reset(self): + super().reset() + self.hypotheses = None + + def get_hypotheses(self): + """ + Returns the hypotheses generated during the last call to update. + """ + if self.hypotheses is None: + raise ValueError("No hypotheses available. Please call update() first.") + return deepcopy(self.hypotheses) diff --git a/nemo/collections/asr/models/asr_eou_models.py b/nemo/collections/asr/models/asr_eou_models.py new file mode 100644 index 000000000000..2a42bbb85410 --- /dev/null +++ b/nemo/collections/asr/models/asr_eou_models.py @@ -0,0 +1,949 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
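+
+"""ASR models with joint speech recognition and end-of-utterance (EOU) prediction,
+built on top of the RNNT BPE and hybrid RNNT-CTC BPE model classes."""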
+ +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict + +from nemo.collections.asr.data.audio_to_eou_label_lhotse import ( + EOB_LABEL, + EOB_STRING, + EOU_LABEL, + EOU_STRING, + AudioToTextEOUBatch, + LhotseSpeechToTextBpeEOUDataset, +) +from nemo.collections.asr.metrics.wer import WER +from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel, EncDecRNNTBPEModel +from nemo.collections.asr.parts.mixins import TranscribeConfig +from nemo.collections.asr.parts.utils.eou_utils import ( + EOUResult, + cal_eou_metrics_from_frame_labels, + flatten_nested_list, +) +from nemo.collections.asr.parts.utils.manifest_utils import write_manifest +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.utils import move_data_to_device +from nemo.core.classes.mixins import AccessMixin +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.utils import logging + +__all__ = ['EncDecRNNTBPEEOUModel', 'EncDecHybridRNNTCTCBPEEOUModel'] + + +@dataclass +class EOUPrediction: + eou_probs: Optional[List[float]] = None + eob_probs: Optional[List[float]] = None + eou_preds: Optional[List[bool]] = None + eob_preds: Optional[List[bool]] = None + + +class ASREOUModelMixin: + def __init__(self): + if not hasattr(self, 'tokenizer'): + self.tokenizer = None + if not hasattr(self, 'eou_token'): + self.eou_token = None + if not hasattr(self, 'eob_token'): + self.eob_token = None + if not hasattr(self, 'frame_len_in_secs'): + self.frame_len_in_secs = None + + def setup_eou_mixin(self, eou_token: int, eob_token: int, frame_len_in_secs: float, tokenizer): + if getattr(self, 'eou_token', None) is None: + self.eou_token = eou_token + if getattr(self, 'eob_token', None) is None: + self.eob_token = eob_token + if getattr(self, 'frame_len_in_secs', None) is None: + self.frame_len_in_secs = frame_len_in_secs + if getattr(self, 'tokenizer', None) is None: + self.tokenizer = tokenizer + + def _patch_decoding_cfg(self, cfg: DictConfig): + """ + Patch the decoding config as needed for EOU computation + """ + with open_dict(cfg): + cfg.decoding.preserve_alignments = True + cfg.decoding.compute_timestamps = True + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """ + PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + """ + batch = move_data_to_device(batch, device) + return batch + + def _get_text_from_tokens(self, tokens: torch.Tensor, tokens_len: Optional[torch.Tensor] = None) -> List[str]: + """ + Convert tokens to text. + Args: + tokens: tensor of tokens + Returns: + text: list of text + """ + text_list = [] + for i in range(len(tokens)): + tokens_i = tokens[i] + if tokens_len is not None: + tokens_i = tokens[i][: tokens_len[i]] + tokens_i = [int(x) for x in tokens_i if x < self.tokenizer.vocab_size] + text = self.tokenizer.ids_to_text(tokens_i) + text_list.append(text) + return text_list + + def _get_eou_predictions_from_hypotheses( + self, hypotheses: List[Hypothesis], batch: AudioToTextEOUBatch + ) -> List[EOUPrediction]: + """ + Get EOU predictions from the hypotheses. 
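+        For CTC hypotheses, `hyp.alignments` is a (frame_logits, frame_tokens) tuple, so EOU/EOB probabilities
+        are read per frame directly; for RNNT, each frame holds a list of (logits, token) pairs, and the maximum
+        EOU/EOB probability over that list is used for the frame.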
+ Args: + hypotheses: batch of hypotheses + Returns: + eou_predictions: list of EOU predictions + """ + eou_predictions = [] + + for hyp in hypotheses: + # Process one hypothesis at a time + eou_probs = [] + eob_probs = [] + eou_preds = [] + eob_preds = [] + if isinstance(hyp.alignments, tuple): + # CTC + probs = torch.softmax(hyp.alignments[0], dim=-1) # [time, num_classes] + tokens = hyp.alignments[1] + eou_probs = probs[:, self.eou_token].tolist() + eob_probs = probs[:, self.eob_token].tolist() + eou_preds = [int(x) == self.eou_token for x in tokens] + eob_preds = [int(x) == self.eob_token for x in tokens] + else: + # RNNT, each timestamp has a list of (prob, token) tuples + for alignment in hyp.alignments: + # Process for each timestamp + probs = torch.softmax(torch.stack([a[0] for a in alignment], dim=0), dim=-1) # unfold RNNT preds + tokens = torch.stack([a[1] for a in alignment], dim=0) # unfold RNNT preds + + # Get the max prob for eou and eob + # and check if eou and eob are predicted + max_eou_prob = probs[:, self.eou_token].max().item() + max_eob_prob = probs[:, self.eob_token].max().item() + eou_pred = torch.any(tokens == self.eou_token).item() + eob_pred = torch.any(tokens == self.eob_token).item() + + eou_probs.append(max_eou_prob) + eob_probs.append(max_eob_prob) + eou_preds.append(eou_pred) + eob_preds.append(eob_pred) + + eou_predictions.append( + EOUPrediction( + eou_probs=eou_probs, + eob_probs=eob_probs, + eou_preds=eou_preds, + eob_preds=eob_preds, + ) + ) + + return eou_predictions + + def _pad_to_same_length(self, eou_labels: List[float], eou_preds: List[float]) -> Tuple[List[float], List[float]]: + """ + Pad the EOU labels and predictions to the same length. + Args: + eou_labels: list of EOU labels + eou_preds: list of EOU predictions + Returns: + eou_labels: list of EOU labels, padded to the same length + eou_preds: list of EOU predictions, padded to the same length + """ + if len(eou_labels) < len(eou_preds): + eou_labels = eou_labels + [0] * (len(eou_preds) - len(eou_labels)) + elif len(eou_labels) > len(eou_preds): + eou_preds = eou_preds + [0] * (len(eou_labels) - len(eou_preds)) + return eou_labels, eou_preds + + def _calculate_eou_metrics( + self, eou_predictions: List[EOUPrediction], batch: AudioToTextEOUBatch + ) -> Tuple[List[EOUResult], List[EOUResult]]: + """ + Calculate EOU metrics. 
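+        Frame-level EOU/EOB predictions and references are converted to binary sequences, zero-padded to equal
+        length, and scored with `cal_eou_metrics_from_frame_labels` using a threshold and collar of 0 and frame
+        length `self.frame_len_in_secs`.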
+ Args: + eou_predictions: list of EOU predictions + batch: batch of data + Returns: + eou_metrics_list: list of EOU metrics, each is of type EOUResult + eob_metrics_list: list of EOB metrics, each is of type EOUResult + """ + # Get the ground truth EOU labels + eou_labels = batch.eou_targets + eou_labels_len = batch.eou_target_lengths + + # Calculate EOU metrics + eou_metrics_list = [] + eob_metrics_list = [] + for i, eou_prediction in enumerate(eou_predictions): + eou_preds_i = [float(x) for x in eou_prediction.eou_preds] + eob_preds_i = [float(x) for x in eou_prediction.eob_preds] + + eou_labels_i = (eou_labels[i][: eou_labels_len[i]] == EOU_LABEL).float().tolist() + eob_labels_i = (eou_labels[i][: eou_labels_len[i]] == EOB_LABEL).float().tolist() + + # Pad the EOU labels and predictions to the same length with zeros + eou_labels_i, eou_preds_i = self._pad_to_same_length(eou_labels_i, eou_preds_i) + eob_labels_i, eob_preds_i = self._pad_to_same_length(eob_labels_i, eob_preds_i) + + # Calculate EOU metrics + eou_metrics = cal_eou_metrics_from_frame_labels( + prediction=eou_preds_i, + reference=eou_labels_i, + threshold=0.0, + collar=0.0, + frame_len_in_secs=self.frame_len_in_secs, + ) # type: EOUResult + + eob_metrics = cal_eou_metrics_from_frame_labels( + prediction=eob_preds_i, + reference=eob_labels_i, + threshold=0.0, + collar=0.0, + frame_len_in_secs=self.frame_len_in_secs, + ) + + eou_metrics_list.append(eou_metrics) + eob_metrics_list.append(eob_metrics) + + return eou_metrics_list, eob_metrics_list + + def _get_percentiles(self, values: List[float], percentiles: List[float], tag: str = "") -> Dict[str, float]: + """ + Get the percentiles of a list of values. + Args: + values: list of values + percentiles: list of percentiles + Returns: + metrics: Dict of percentiles + """ + if len(values) == 0: + return [0.0] * len(percentiles) + results = np.percentile(values, percentiles).tolist() + metrics = {} + if tag: + tag += "_" + for i, p in enumerate(percentiles): + metrics[f'{tag}p{int(p)}'] = float(results[i]) + return metrics + + def _aggregate_eou_metrics(self, outputs: List[dict], mode: str, is_ctc: bool = False): + if f'{mode}_eou_metrics' not in outputs[0] and not is_ctc: + return {} + if f'{mode}_eou_metrics_ctc' not in outputs[0] and is_ctc: + return {} + + # Aggregate EOU/EOB metrics + eou_metrics = [] # type: List[EOUResult] + eob_metrics = [] # type: List[EOUResult] + for x in outputs: + if is_ctc: + eou_metrics.extend(x[f'{mode}_eou_metrics_ctc']) + eob_metrics.extend(x[f'{mode}_eob_metrics_ctc']) + else: + eou_metrics.extend(x[f'{mode}_eou_metrics']) + eob_metrics.extend(x[f'{mode}_eob_metrics']) + num_eou_utterances = sum([x.num_utterances for x in eou_metrics]) + eou_latency = flatten_nested_list([x.latency for x in eou_metrics]) + eou_early_cutoff = flatten_nested_list([x.early_cutoff for x in eou_metrics]) + + num_eob_utterances = sum([x.num_utterances for x in eob_metrics]) + eob_latency = flatten_nested_list([x.latency for x in eob_metrics]) + eob_early_cutoff = flatten_nested_list([x.early_cutoff for x in eob_metrics]) + + eou_avg_num_early_cutoff = len(eou_early_cutoff) / num_eou_utterances if num_eou_utterances > 0 else 0.0 + eob_avg_num_early_cutoff = len(eob_early_cutoff) / num_eob_utterances if num_eob_utterances > 0 else 0.0 + if len(eou_latency) == 0: + eou_latency = [0.0] + if len(eou_early_cutoff) == 0: + eou_early_cutoff = [0.0] + if len(eob_latency) == 0: + eob_latency = [0.0] + if len(eob_early_cutoff) == 0: + eob_early_cutoff = [0.0] + + eou_missing = 
[x.missing for x in eou_metrics] + eob_missing = [x.missing for x in eob_metrics] + + tensorboard_logs = {} + target_percentiles = [50, 90, 95] + eou_latency_metrics = self._get_percentiles(eou_latency, target_percentiles, tag=f'{mode}_eou_latency') + eou_early_cutoff_metrics = self._get_percentiles( + eou_early_cutoff, target_percentiles, tag=f'{mode}_eou_early_cutoff' + ) + eob_latency_metrics = self._get_percentiles(eob_latency, target_percentiles, tag=f'{mode}_eob_latency') + eob_early_cutoff_metrics = self._get_percentiles( + eob_early_cutoff, target_percentiles, tag=f'{mode}_eob_early_cutoff' + ) + + tensorboard_logs.update(eou_latency_metrics) + tensorboard_logs.update(eou_early_cutoff_metrics) + tensorboard_logs.update(eob_latency_metrics) + tensorboard_logs.update(eob_early_cutoff_metrics) + + tensorboard_logs[f'{mode}_eou_early_cutoff_avg_num'] = eou_avg_num_early_cutoff + tensorboard_logs[f'{mode}_eob_early_cutoff_avg_num'] = eob_avg_num_early_cutoff + + tensorboard_logs[f'{mode}_eou_missing'] = ( + sum(eou_missing) / num_eou_utterances if num_eou_utterances > 0 else 0.0 + ) + tensorboard_logs[f'{mode}_eob_missing'] = ( + sum(eob_missing) / num_eob_utterances if num_eob_utterances > 0 else 0.0 + ) + + return tensorboard_logs + + @rank_zero_only + def _maybe_save_predictions( + self, outputs: List[Dict], mode: str = "val", dataloader_idx: int = 0 + ) -> Optional[Path]: + """ + Save predictions to disk. + Args: + outputs: list of outputs + mode: mode of the model, either 'val' or 'test' + Returns: + Path object if predictions are saved, None otherwise. + """ + + if not self.cfg.get('save_pred_to_file', None): + return None + + output_file = Path(self.cfg.save_pred_to_file) + output_file.parent.mkdir(parents=True, exist_ok=True) + + if getattr(self, '_validation_names', None): + output_file = output_file.with_name(f"{self._validation_names[dataloader_idx]}_{output_file.name}") + else: + output_file = output_file.with_suffix(f'.{dataloader_idx}.json') + + manifest = [] + for output in outputs: + for i in range(len(output[f'{mode}_sample_id'])): + item = { + "sample_id": output[f'{mode}_sample_id'][i], + "audio_filepath": output[f'{mode}_audio_filepath'][i], + "eou_text": output[f'{mode}_text_gt'][i], + "eou_pred_text": output[f'{mode}_text_pred'][i], + "is_backchannel": bool(str(output[f'{mode}_text_gt'][i]).endswith(EOB_STRING)), + } + if f"{mode}_text_pred_ctc" in output: + item["eou_pred_text_ctc"] = output[f"{mode}_text_pred_ctc"][i] + + eou_metrics = {f"eou_{k}": v for k, v in output[f"{mode}_eou_metrics"][i].to_dict().items()} + eob_metrics = {f"eob_{k}": v for k, v in output[f"{mode}_eob_metrics"][i].to_dict().items()} + item.update(eou_metrics) + item.update(eob_metrics) + manifest.append(item) + write_manifest(output_file, manifest) + logging.info(f"Predictions saved to {output_file}") + return output_file + + +class EncDecRNNTBPEEOUModel(EncDecRNNTBPEModel, ASREOUModelMixin): + def __init__(self, cfg: DictConfig, trainer): + + self._patch_decoding_cfg(cfg) + super().__init__(cfg=cfg, trainer=trainer) + + self.eou_token = self.tokenizer.token_to_id(EOU_STRING) + self.eob_token = self.tokenizer.token_to_id(EOB_STRING) + self.frame_len_in_secs = self.cfg.preprocessor.window_stride * self.cfg.encoder.subsampling_factor + + self.setup_eou_mixin(self.eou_token, self.eob_token, self.frame_len_in_secs, self.tokenizer) + + self.wer = WER( + decoding=self.decoding, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + log_prediction=self._cfg.get('log_prediction', True), 
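+            # return_hypotheses=True so validation_pass can pull decoded hypotheses (and alignments) for EOU metrics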
+ dist_sync_on_step=True, + return_hypotheses=True, + ) + + # Setup fused Joint step if flag is set + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + cfg = OmegaConf.create(config) if not isinstance(config, DictConfig) else config + dataset = LhotseSpeechToTextBpeEOUDataset( + cfg=cfg, tokenizer=self.tokenizer, return_cuts=config.get("do_transcribe", False) + ) + return get_lhotse_dataloader_from_config( + config, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), + dataset=dataset, + tokenizer=self.tokenizer, + ) + + def _transcribe_forward(self, batch: AudioToTextEOUBatch, trcfg: TranscribeConfig): + encoded, encoded_len = self.forward(input_signal=batch.audio_signal, input_signal_length=batch.audio_lengths) + output = dict(encoded=encoded, encoded_len=encoded_len) + return output + + def training_step(self, batch: AudioToTextEOUBatch, batch_nb): + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + signal = batch.audio_signal + signal_len = batch.audio_lengths + transcript = batch.text_tokens + transcript_len = batch.text_token_lengths + + # forward() only performs encoder forward + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + # During training, loss must be computed, so decoder forward is necessary + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + else: + log_every_n_steps = 1 + sample_id = batch_nb + + # If experimental fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + # Compute full joint and loss + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value) + + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if (sample_id + 1) % log_every_n_steps == 0: + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + _, scores, words = self.wer.compute() + self.wer.reset() + tensorboard_logs.update({'training_batch_wer': scores.float() / words}) + + else: + # If experimental fused Joint-Loss-WER is used + if (sample_id + 1) % log_every_n_steps == 0: + compute_wer = True + else: + compute_wer = False + + # Fused joint step + loss_value, wer, _, _ = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoder, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=transcript_len, + compute_wer=compute_wer, + ) 
+ + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value) + + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if compute_wer: + tensorboard_logs.update({'training_batch_wer': wer}) + + # Log items + self.log_dict(tensorboard_logs) + + # Preserve batch acoustic model T and language model U parameters if normalizing + if self._optim_normalize_joint_txu: + self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()] + + return {'loss': loss_value} + + def predict_step(self, batch: AudioToTextEOUBatch, batch_idx, dataloader_idx=0): + signal = batch.audio_signal + signal_len = batch.audio_lengths + + # forward() only performs encoder forward + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False + ) + + return list(best_hyp_text) + + def validation_pass(self, batch: AudioToTextEOUBatch, batch_idx: int, dataloader_idx: int = 0): + signal = batch.audio_signal + signal_len = batch.audio_lengths + transcript = batch.text_tokens + transcript_len = batch.text_token_lengths + + # forward() only performs encoder forward + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + tensorboard_logs = {} + + if self.cfg.get('save_pred_to_file', None): + text_gt = self._get_text_from_tokens(transcript, transcript_len) + tensorboard_logs['val_sample_id'] = batch.sample_ids + tensorboard_logs['val_audio_filepath'] = batch.audio_filepaths + tensorboard_logs['val_text_gt'] = text_gt + # If experimental fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + if self.compute_eval_loss: + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + tensorboard_logs['val_loss'] = loss_value + + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + hypotheses = self.wer.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses]) + tensorboard_logs['val_text_pred'] = text_pred + + if self.cfg.get('calculate_eou_metrics', True): + eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch) + eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch) + else: + eou_metrics_list = [] + eob_metrics_list = [] + + wer, wer_num, wer_denom = self.wer.compute() + self.wer.reset() + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + tensorboard_logs['val_eou_metrics'] = eou_metrics_list + tensorboard_logs['val_eob_metrics'] = eob_metrics_list + + else: + # If experimental fused Joint-Loss-WER is used + compute_wer = True + + if self.compute_eval_loss: + decoded, target_len, states = self.decoder(targets=transcript, target_length=transcript_len) + else: + 
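+                # No eval loss requested: skip the decoder forward used for the loss;
+                # the fused joint below still decodes to compute WER and hypotheses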
decoded = None + target_len = transcript_len + + # Fused joint step + loss_value, wer, wer_num, wer_denom = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoded, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=target_len, + compute_wer=compute_wer, + keep_hypotheses=True, + ) + + hypotheses = self.joint.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses]) + tensorboard_logs['val_text_pred'] = text_pred + + if self.cfg.get('calculate_eou_metrics', True): + eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch) + eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch) + else: + eou_metrics_list = [] + eob_metrics_list = [] + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + tensorboard_logs['val_eou_metrics'] = eou_metrics_list + tensorboard_logs['val_eob_metrics'] = eob_metrics_list + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return tensorboard_logs + + def multi_inference_epoch_end(self, outputs, dataloader_idx: int = 0, mode: str = "val"): + assert mode in ['val', 'test'], f"Invalid mode: {mode}. Must be 'val' or 'test'." + + if not outputs: + logging.warning( + f"No outputs received for {mode} dataloader {dataloader_idx}. Skipping epoch end processing." + ) + return {} + + self._maybe_save_predictions(outputs, mode=mode, dataloader_idx=dataloader_idx) + + # Aggregate WER metrics + if self.compute_eval_loss: + loss_mean = torch.stack([x[f'{mode}_loss'] for x in outputs]).mean() + loss_log = {f'{mode}_loss': loss_mean} + else: + loss_log = {} + wer_num = torch.stack([x[f'{mode}_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x[f'{mode}_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {**loss_log, f'{mode}_wer': wer_num.float() / wer_denom} + + eou_metrics = {} + if self.cfg.get('calculate_eou_metrics', True): + eou_metrics = self._aggregate_eou_metrics(outputs, mode=mode) + tensorboard_logs.update(eou_metrics) + + return {**loss_log, 'log': tensorboard_logs} + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='val') + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='test') + + @property + def oomptimizer_schema(self) -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. 
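+        The schema mirrors the fields of `AudioToTextEOUBatch`; EOU targets use a 4-label vocabulary
+        (non-speech, speech, end of utterance, end of backchannel).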
+ """ + return { + "cls": AudioToTextEOUBatch, + "inputs": [ + {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "audio_signal"}, + {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "audio_lengths"}, + { + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "name": "text_tokens", + "vocab_size": self.tokenizer.vocab_size, + }, + {"type": NeuralType(("B",), LengthsType()), "seq_length": "output", "name": "text_token_lengths"}, + { + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "name": "eou_targets", + "vocab_size": 4, + }, + {"type": NeuralType(("B",), LengthsType()), "seq_length": "output", "name": "eou_target_lengths"}, + ], + } + + +class EncDecHybridRNNTCTCBPEEOUModel(EncDecHybridRNNTCTCBPEModel, ASREOUModelMixin): + def __init__(self, cfg: DictConfig, trainer): + self._patch_decoding_cfg(cfg) + if cfg.aux_ctc.get('decoding', None) is not None: + with open_dict(cfg): + cfg.aux_ctc.decoding.preserve_alignments = True + cfg.aux_ctc.decoding.compute_timestamps = True + + super().__init__(cfg=cfg, trainer=trainer) + + self.eou_token = self.tokenizer.token_to_id(EOU_STRING) + self.eob_token = self.tokenizer.token_to_id(EOB_STRING) + self.frame_len_in_secs = self.cfg.preprocessor.window_stride * self.cfg.encoder.subsampling_factor + self.setup_eou_mixin(self.eou_token, self.eob_token, self.frame_len_in_secs, self.tokenizer) + + self.wer = WER( + decoding=self.decoding, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + log_prediction=self._cfg.get('log_prediction', True), + dist_sync_on_step=True, + return_hypotheses=True, + ) + + self.ctc_wer = WER( + decoding=self.ctc_decoding, + use_cer=self.cfg.aux_ctc.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self.cfg.get("log_prediction", False), + return_hypotheses=True, + ) + + # Setup fused Joint step if flag is set + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + cfg = OmegaConf.create(config) if not isinstance(config, DictConfig) else config + dataset = LhotseSpeechToTextBpeEOUDataset( + cfg=cfg, tokenizer=self.tokenizer, return_cuts=config.get("do_transcribe", False) + ) + return get_lhotse_dataloader_from_config( + config, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. 
+            global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+            world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
+            dataset=dataset,
+            tokenizer=self.tokenizer,
+        )
+
+    def training_step(self, batch: AudioToTextEOUBatch, batch_nb):
+        signal = batch.audio_signal
+        signal_len = batch.audio_lengths
+        transcript = batch.text_tokens
+        transcript_len = batch.text_token_lengths
+
+        new_batch = (signal, signal_len, transcript, transcript_len)
+        return super().training_step(new_batch, batch_nb)
+
+    def predict_step(self, batch: AudioToTextEOUBatch, batch_idx, dataloader_idx=0):
+        signal = batch.audio_signal
+        signal_len = batch.audio_lengths
+        transcript = batch.text_tokens
+        transcript_len = batch.text_token_lengths
+        sample_ids = batch.sample_ids
+        new_batch = (signal, signal_len, transcript, transcript_len, sample_ids)
+        return super().predict_step(new_batch, batch_idx, dataloader_idx)
+
+    def validation_pass(self, batch: AudioToTextEOUBatch, batch_idx: int, dataloader_idx: int = 0):
+        signal = batch.audio_signal
+        signal_len = batch.audio_lengths
+        transcript = batch.text_tokens
+        transcript_len = batch.text_token_lengths
+
+        encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len)
+        del signal
+
+        tensorboard_logs = {}
+
+        if self.cfg.get('save_pred_to_file', None):
+            text_gt = self._get_text_from_tokens(transcript, transcript_len)
+            tensorboard_logs['val_sample_id'] = batch.sample_ids
+            tensorboard_logs['val_audio_filepath'] = batch.audio_filepaths
+            tensorboard_logs['val_text_gt'] = text_gt
+
+        loss_value = None
+
+        # If experimental fused Joint-Loss-WER is not used
+        if not self.joint.fuse_loss_wer:
+            if self.compute_eval_loss:
+                decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len)
+                joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
+
+                loss_value = self.loss(
+                    log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
+                )
+                tensorboard_logs['val_loss'] = loss_value
+
+            self.wer.update(
+                predictions=encoded,
+                predictions_lengths=encoded_len,
+                targets=transcript,
+                targets_lengths=transcript_len,
+            )
+
+            hypotheses = self.wer.get_hypotheses()
+
+            if self.cfg.get('save_pred_to_file', None):
+                text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses])
+                tensorboard_logs['val_text_pred'] = text_pred
+
+            eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch)
+            eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch)
+
+            wer, wer_num, wer_denom = self.wer.compute()
+            self.wer.reset()
+
+            tensorboard_logs['val_wer_num'] = wer_num
+            tensorboard_logs['val_wer_denom'] = wer_denom
+            tensorboard_logs['val_wer'] = wer
+            tensorboard_logs['val_eou_metrics'] = eou_metrics_list
+            tensorboard_logs['val_eob_metrics'] = eob_metrics_list
+
+        else:
+            # If experimental fused Joint-Loss-WER is used
+            compute_wer = True
+
+            if self.compute_eval_loss:
+                decoded, target_len, states = self.decoder(targets=transcript, target_length=transcript_len)
+            else:
+                decoded = None
+                target_len = transcript_len
+
+            # Fused joint step
+            loss_value, wer, wer_num, wer_denom = self.joint(
+                encoder_outputs=encoded,
+                decoder_outputs=decoded,
+                encoder_lengths=encoded_len,
+                transcripts=transcript,
+                transcript_lengths=target_len,
+                compute_wer=compute_wer,
+                keep_hypotheses=True,
+            )
+
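+            # `keep_hypotheses=True` asks the fused joint to cache the hypotheses produced while
+            # computing WER, so they can be fetched right after via `self.joint.get_hypotheses()`;
+            # that call raises a ValueError if the flag was not set on the last forward pass.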
hypotheses = self.joint.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses]) + tensorboard_logs['val_text_pred'] = text_pred + + eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch) + + eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch) + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + tensorboard_logs['val_eou_metrics'] = eou_metrics_list + tensorboard_logs['val_eob_metrics'] = eob_metrics_list + + log_probs = self.ctc_decoder(encoder_output=encoded) + if self.compute_eval_loss: + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + tensorboard_logs['val_ctc_loss'] = ctc_loss + tensorboard_logs['val_rnnt_loss'] = loss_value + loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss + tensorboard_logs['val_loss'] = loss_value + + self.ctc_wer.update( + predictions=log_probs, + targets=transcript, + targets_lengths=transcript_len, + predictions_lengths=encoded_len, + ) + hypotheses_ctc = self.ctc_wer.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred_ctc = self._get_text_from_tokens([x.y_sequence for x in hypotheses_ctc]) + tensorboard_logs['val_text_pred_ctc'] = text_pred_ctc + + eou_predictions_ctc = self._get_eou_predictions_from_hypotheses(hypotheses_ctc, batch) + eou_metrics_list_ctc, eob_metrics_list_ctc = self._calculate_eou_metrics(eou_predictions_ctc, batch) + + ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute() + self.ctc_wer.reset() + + tensorboard_logs['val_wer_num_ctc'] = ctc_wer_num + tensorboard_logs['val_wer_denom_ctc'] = ctc_wer_denom + tensorboard_logs['val_wer_ctc'] = ctc_wer + tensorboard_logs['val_eou_metrics_ctc'] = eou_metrics_list_ctc + tensorboard_logs['val_eob_metrics_ctc'] = eob_metrics_list_ctc + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + loss_value, additional_logs = self.add_interctc_losses( + loss_value, + transcript, + transcript_len, + compute_wer=True, + compute_loss=self.compute_eval_loss, + log_wer_num_denom=True, + log_prefix="val_", + ) + if self.compute_eval_loss: + # overriding total loss value. Note that the previous + # rnnt + ctc loss is available in metrics as "val_final_loss" now + tensorboard_logs['val_loss'] = loss_value + tensorboard_logs.update(additional_logs) + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + return tensorboard_logs + + def multi_inference_epoch_end(self, outputs, dataloader_idx: int = 0, mode: str = "val"): + assert mode in ['val', 'test'], f"Invalid mode: {mode}. Must be 'val' or 'test'." 
+        self._maybe_save_predictions(outputs, mode=mode, dataloader_idx=dataloader_idx)
+
+        # Aggregate WER metrics
+        if self.compute_eval_loss:
+            loss_mean = torch.stack([x[f'{mode}_loss'] for x in outputs]).mean()
+            loss_log = {f'{mode}_loss': loss_mean}
+        else:
+            loss_log = {}
+        wer_num = torch.stack([x[f'{mode}_wer_num'] for x in outputs]).sum()
+        wer_denom = torch.stack([x[f'{mode}_wer_denom'] for x in outputs]).sum()
+        tensorboard_logs = {**loss_log, f'{mode}_wer': wer_num.float() / wer_denom}
+
+        if self.ctc_loss_weight > 0:
+            ctc_wer_num = torch.stack([x[f'{mode}_wer_num_ctc'] for x in outputs]).sum()
+            ctc_wer_denom = torch.stack([x[f'{mode}_wer_denom_ctc'] for x in outputs]).sum()
+            tensorboard_logs[f'{mode}_wer_ctc'] = ctc_wer_num.float() / ctc_wer_denom
+
+        eou_metrics = self._aggregate_eou_metrics(outputs, mode)
+        tensorboard_logs.update(eou_metrics)
+
+        eou_metrics_ctc = self._aggregate_eou_metrics(outputs, mode, is_ctc=True)
+        for key, value in eou_metrics_ctc.items():
+            tensorboard_logs[f'{key}_ctc'] = value
+
+        return {**loss_log, 'log': tensorboard_logs}
+
+    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
+        return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='val')
+
+    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
+        return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='test')
diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py
index 14abdd0d2776..310b89bf9f8c 100644
--- a/nemo/collections/asr/modules/__init__.py
+++ b/nemo/collections/asr/modules/__init__.py
@@ -20,7 +20,11 @@
     SpectrogramAugmentation,
 )
 from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM
-from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder, ConformerEncoderAdapter
+from nemo.collections.asr.modules.conformer_encoder import (
+    ConformerEncoder,
+    ConformerEncoderAdapter,
+    ConformerMultiLayerFeatureExtractor,
+)
 from nemo.collections.asr.modules.conv_asr import (
     ConvASRDecoder,
     ConvASRDecoderClassification,
@@ -45,10 +49,47 @@
 )
 from nemo.collections.asr.modules.squeezeformer_encoder import SqueezeformerEncoder, SqueezeformerEncoderAdapter
 from nemo.collections.asr.modules.ssl_modules import (
-    ConformerMultiLayerFeatureExtractor,
     ConformerMultiLayerFeaturePreprocessor,
     ConvFeatureMaksingWrapper,
     MultiSoftmaxDecoder,
     RandomBlockMasking,
     RandomProjectionVectorQuantizer,
 )
+
+
+__all__ = [
+    'AudioToMelSpectrogramPreprocessor',
+    'AudioToMFCCPreprocessor',
+    'CropOrPadSpectrogramAugmentation',
+    'MaskedPatchAugmentation',
+    'SpectrogramAugmentation',
+    'BeamSearchDecoderWithLM',
+    'ConformerEncoder',
+    'ConformerEncoderAdapter',
+    'ConformerMultiLayerFeatureExtractor',
+    'ConvASRDecoder',
+    'ConvASRDecoderClassification',
+    'ConvASRDecoderReconstruction',
+    'ConvASREncoder',
+    'ConvASREncoderAdapter',
+    'ECAPAEncoder',
+    'ParallelConvASREncoder',
+    'SpeakerDecoder',
+    'ViterbiDecoderWithGraph',
+    'HATJoint',
+    'LSTMDecoder',
+    'MSDD_module',
+    'RNNEncoder',
+    'RNNTDecoder',
+    'RNNTDecoderJointSSL',
+    'RNNTJoint',
+    'SampledRNNTJoint',
+    'StatelessTransducerDecoder',
+    'SqueezeformerEncoder',
+    'SqueezeformerEncoderAdapter',
+    'ConformerMultiLayerFeaturePreprocessor',
+    'ConvFeatureMaksingWrapper',
+    'MultiSoftmaxDecoder',
+    'RandomBlockMasking',
+    'RandomProjectionVectorQuantizer',
+]
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index 2a7ffeba5d3c..36371e6cf6a7 100644
---
a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -16,7 +16,7 @@ import random from collections import OrderedDict from dataclasses import dataclass -from typing import List, Optional, Set, Tuple +from typing import Callable, List, Optional, Set, Tuple import torch import torch.distributed @@ -56,7 +56,7 @@ ) from nemo.utils import logging -__all__ = ['ConformerEncoder'] +__all__ = ['ConformerEncoder', 'ConformerMultiLayerFeatureExtractor'] class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): @@ -1266,44 +1266,54 @@ def get_accepted_adapter_types( class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable, AccessMixin): - """ - A wrapper module that extracts features from multiple layers of a ConformerEncoder, - by reusing existing mechanisim for interctc loss. - To use it, set `layer_idx_list` to specify the indices of layers to extract from. - Also, you can specify an `aggretator` module to aggregate the features from different layers, - default not aggregating. - """ - def __init__( self, encoder: ConformerEncoder, - layer_idx_list: List[int], - aggregator: NeuralModule = None, - detach: bool = False, - convert_to_cpu: bool = False, + aggregator: Optional[Callable] = None, + layer_idx_list: Optional[List[int]] = None, ): + """ + This class is used to extract features from different layers of the ConformerEncoder. + Args: + encoder: ConformerEncoder instance. + aggregator: Aggregator instance. If None, the features are returned as a list. + layer_idx_list: List of layer indices to extract features from. If None, all layers are extracted. + """ super().__init__() self.encoder = encoder - self.layer_idx_list = [int(lyr_idx) for lyr_idx in layer_idx_list] - for x in self.layer_idx_list: - if x < 0 or x >= len(encoder.layers): - raise ValueError(f"layer index {x} out of range [0, {len(encoder.layers)})") - self.enc_access_cfg = { + self.aggregator = aggregator + self.num_layers = len(encoder.layers) + self.layer_idx_list = [] + if not layer_idx_list: + layer_idx_list = list(range(self.num_layers)) + for lid in layer_idx_list: + if lid < -self.num_layers or lid >= self.num_layers: + raise ValueError(f"Invalid layer index {lid} for ConformerEncoder with {self.num_layers} layers.") + if lid < 0: + lid = self.num_layers + lid + self.layer_idx_list.append(lid) + self.layer_idx_list.sort() + logging.info(f"Extracting features from layers: {self.layer_idx_list}") + self.access_cfg = { "interctc": { "capture_layers": self.layer_idx_list, }, - "detach": detach, - "convert_to_cpu": convert_to_cpu, + "detach": False, + "convert_to_cpu": False, } - self.aggregator = aggregator + self._is_access_enabled = False def forward( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ) -> Tuple[torch.Tensor, torch.Tensor]: - # pylint: disable=missing-function-docstring - old_access_flag = self.is_access_enabled(guid=getattr(self, "model_guid", None)) - self.update_access_cfg(self.enc_access_cfg, guid=getattr(self, "model_guid", None)) - self.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) + """ + Args: + same interface as ConformerEncoder.forward() + Returns: + tuple of aggregated features of shape [B, D, T] and lengths of shape [B] + """ + self.encoder.update_access_cfg(self.access_cfg, guid=getattr(self, "model_guid", None)) + self.encoder.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) _ = 
self.encoder( audio_signal=audio_signal, @@ -1313,9 +1323,8 @@ def forward( cache_last_channel_len=cache_last_channel_len, ) - # Chunk of code adapted from ConformerEncoder.forward_internal() total_registry = {} - for module_registry in self.get_module_registry(self.encoder).values(): + for module_registry in self.encoder.get_module_registry(self.encoder).values(): for key in module_registry: if key.startswith("interctc/") and key in total_registry: raise RuntimeError(f"layer {key} has been logged multiple times!") @@ -1329,8 +1338,8 @@ def forward( layer_lengths = total_registry[f"interctc/layer_length_{layer_idx}"] except KeyError: raise RuntimeError( - f"Intermediate layer {layer_idx} was not captured! " - "Check the layer index and the number of ConformerEncoder layers." + f"Intermediate layer {layer_idx} was not captured! Check the layer index and the number of " + "ConformerEncoder layers." ) if len(layer_outputs) > 1 or len(layer_lengths) > 1: raise RuntimeError("Make sure encoder.forward is called exactly one time") @@ -1338,13 +1347,9 @@ def forward( encoded_len_list.append(layer_lengths[0]) # [B] self.encoder.reset_registry() - self.set_access_enabled(access_enabled=old_access_flag, guid=getattr(self, "model_guid", None)) - # End of the adapted chunk - - if self.aggregator is not None: - return self.aggregator(encoded_list, encoded_len_list) # Tensor[B,D*L,T], Tensor[B] - else: - return encoded_list, encoded_len_list # List[Tensor[B,D,T]], List[Tensor[B]] + if self.aggregator is None: + return encoded_list, encoded_len_list + return self.aggregator(encoded_list, encoded_len_list) # Register any additional information diff --git a/nemo/collections/asr/modules/lstm_decoder.py b/nemo/collections/asr/modules/lstm_decoder.py index 9bb60e2fabca..8c41c5657f52 100644 --- a/nemo/collections/asr/modules/lstm_decoder.py +++ b/nemo/collections/asr/modules/lstm_decoder.py @@ -45,7 +45,16 @@ def input_types(self): def output_types(self): return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())}) - def __init__(self, feat_in, num_classes, lstm_hidden_size, vocabulary=None, bidirectional=False, num_layers=1): + def __init__( + self, + feat_in, + num_classes, + lstm_hidden_size, + vocabulary=None, + bidirectional=False, + num_layers=1, + add_blank=True, + ): super().__init__() if vocabulary is not None: @@ -57,7 +66,7 @@ def __init__(self, feat_in, num_classes, lstm_hidden_size, vocabulary=None, bidi self.__vocabulary = vocabulary self._feat_in = feat_in # Add 1 for blank char - self._num_classes = num_classes + 1 + self._num_classes = num_classes + 1 if add_blank else num_classes self.lstm_layer = nn.LSTM( input_size=feat_in, diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index bcc3c24fcb55..a3c37f254b3a 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1358,6 +1358,7 @@ def input_types(self): "transcripts": NeuralType(('B', 'T'), LabelsType(), optional=True), "transcript_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), "compute_wer": NeuralType(optional=True), + "keep_hypotheses": NeuralType(optional=True), } @property @@ -1470,6 +1471,8 @@ def __init__( # to change, requires running ``model.temperature = T`` explicitly self.temperature = 1.0 + self.hypotheses = None + @typecheck() def forward( self, @@ -1479,6 +1482,7 @@ def forward( transcripts: Optional[torch.Tensor] = None, transcript_lengths: Optional[torch.Tensor] = None, compute_wer: bool = False, + keep_hypotheses: 
bool = False, ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]: # encoder = (B, D, T) # decoder = (B, D, U) if passed, else None @@ -1516,6 +1520,7 @@ def forward( wers, wer_nums, wer_denoms = [], [], [] target_lengths = [] batch_size = int(encoder_outputs.size(0)) # actual batch size + hypotheses = [] # Iterate over batch using fused_batch_size steps for batch_idx in range(0, batch_size, self._fused_batch_size): @@ -1600,6 +1605,9 @@ def forward( targets=sub_transcripts, targets_lengths=sub_transcript_lens, ) + + hyp = self.wer.get_hypotheses() if keep_hypotheses else [] + # Sync and all_reduce on all processes, compute global WER wer, wer_num, wer_denom = self.wer.compute() self.wer.reset() @@ -1610,6 +1618,7 @@ def forward( wers.append(wer) wer_nums.append(wer_num) wer_denoms.append(wer_denom) + hypotheses.extend(hyp) del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens @@ -1627,8 +1636,19 @@ def forward( wer_num = None wer_denom = None + self.hypotheses = hypotheses if keep_hypotheses else None return losses, wer, wer_num, wer_denom + def get_hypotheses(self): + """ + Returns the hypotheses generated during the last forward pass. + """ + if self.hypotheses is None: + raise ValueError( + "No hypotheses were generated during the last forward pass. Did you set keep_hypotheses=True in forward()?" + ) + return self.hypotheses + def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor: """ Project the encoder output to the joint hidden dimension. diff --git a/nemo/collections/asr/modules/ssl_modules/__init__.py b/nemo/collections/asr/modules/ssl_modules/__init__.py index dcfefd54fa73..f33127bd7d85 100644 --- a/nemo/collections/asr/modules/ssl_modules/__init__.py +++ b/nemo/collections/asr/modules/ssl_modules/__init__.py @@ -17,9 +17,16 @@ SpeakerNoiseAugmentation, ) from nemo.collections.asr.modules.ssl_modules.masking import ConvFeatureMaksingWrapper, RandomBlockMasking -from nemo.collections.asr.modules.ssl_modules.multi_layer_feat import ( - ConformerMultiLayerFeatureExtractor, - ConformerMultiLayerFeaturePreprocessor, -) +from nemo.collections.asr.modules.ssl_modules.multi_layer_feat import ConformerMultiLayerFeaturePreprocessor from nemo.collections.asr.modules.ssl_modules.multi_softmax_decoder import MultiSoftmaxDecoder from nemo.collections.asr.modules.ssl_modules.quantizers import RandomProjectionVectorQuantizer + +__all__ = [ + 'MultiSpeakerNoiseAugmentation', + 'SpeakerNoiseAugmentation', + 'ConvFeatureMaksingWrapper', + 'RandomBlockMasking', + 'ConformerMultiLayerFeaturePreprocessor', + 'MultiSoftmaxDecoder', + 'RandomProjectionVectorQuantizer', +] diff --git a/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py b/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py index 490d68c52f04..73ca41438437 100644 --- a/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py +++ b/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.distributed import torch.nn as nn -from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor, ConformerEncoder +from nemo.collections.asr.modules import ( + AudioToMelSpectrogramPreprocessor, + ConformerEncoder, + ConformerMultiLayerFeatureExtractor, +) from nemo.core.classes import Exportable, NeuralModule from nemo.core.classes.mixins import AccessMixin -from nemo.utils import logging class Aggregator(nn.Module): @@ -81,85 +84,12 @@ def forward( raise ValueError(f"Unknown mode {self.mode}") -class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable): - def __init__(self, encoder, aggregator: Optional[Callable] = None, layer_idx_list: Optional[List[int]] = None): - """ - Args: - encoder: ConformerEncoder instance. - aggregator: Aggregator instance. - layer_idx_list: List of layer indices to extract features from. - """ - super().__init__() - self.encoder = encoder - self.aggregator = aggregator - self.layer_idx_list = ( - [int(l) for l in layer_idx_list] - if layer_idx_list is not None - else [i for i in range(len(self.encoder.layers))] - ) - for x in self.layer_idx_list: - if x < 0 or x >= len(self.encoder.layers): - raise ValueError(f"layer index {x} out of range [0, {len(self.encoder.layers)})") - logging.info(f"Extracting features from layers {self.layer_idx_list}") - self.access_cfg = { - "interctc": { - "capture_layers": self.layer_idx_list, - }, - "detach": False, - "convert_to_cpu": False, - } - self._is_access_enabled = False - - def forward( - self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - same interface as ConformerEncoder.forward() - Returns: - tuple of aggregated features of shape [B, D, T] and lengths of shape [B] - """ - self.encoder.update_access_cfg(self.access_cfg, guid=getattr(self, "model_guid", None)) - self.encoder.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) - - _ = self.encoder( - audio_signal=audio_signal, - length=length, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_len=cache_last_channel_len, - ) - - total_registry = {} - for module_registry in self.encoder.get_module_registry(self.encoder).values(): - for key in module_registry: - if key.startswith("interctc/") and key in total_registry: - raise RuntimeError(f"layer {key} has been logged multiple times!") - total_registry.update(module_registry) - - encoded_list = [] - encoded_len_list = [] - for layer_idx in self.layer_idx_list: - try: - layer_outputs = total_registry[f"interctc/layer_output_{layer_idx}"] - layer_lengths = total_registry[f"interctc/layer_length_{layer_idx}"] - except KeyError: - raise RuntimeError( - f"Intermediate layer {layer_idx} was not captured! Check the layer index and the number of " - "ConformerEncoder layers." 
- ) - if len(layer_outputs) > 1 or len(layer_lengths) > 1: - raise RuntimeError("Make sure encoder.forward is called exactly one time") - encoded_list.append(layer_outputs[0]) # [B, D, T] - encoded_len_list.append(layer_lengths[0]) # [B] - - self.encoder.reset_registry() - if self.aggregator is None: - return encoded_list, encoded_len_list - return self.aggregator(encoded_list, encoded_len_list) - - class ConformerMultiLayerFeaturePreprocessor(NeuralModule, Exportable, AccessMixin): + """ + This class is used to replace the AudioToMelSpectrogramPreprocessor such that + the input to the actual model encoder is the multi-layer features from a pre-trained ConformerEncoder. + """ + def __init__( self, aggregator: nn.Module, diff --git a/nemo/collections/asr/parts/utils/eou_utils.py b/nemo/collections/asr/parts/utils/eou_utils.py new file mode 100644 index 000000000000..6099e2a6fcfe --- /dev/null +++ b/nemo/collections/asr/parts/utils/eou_utils.py @@ -0,0 +1,263 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List + +import numpy as np + + +@dataclass +class EOUResult: + latency: list + early_cutoff: list + true_positives: int + false_negatives: int + false_positives: int + num_utterances: int + num_predictions: int + missing: int + + def to_dict(self) -> Dict[str, float]: + """ + Convert the EOUResult dataclass to a dictionary. + Returns: + Dict: A dictionary representation of the EOUResult. + """ + return { + 'latency': self.latency, + 'early_cutoff': self.early_cutoff, + 'true_positives': self.true_positives, + 'false_negatives': self.false_negatives, + 'false_positives': self.false_positives, + 'num_utterances': self.num_utterances, + 'num_predictions': self.num_predictions, + 'missing': self.missing, + } + + +def flatten_nested_list(nested_list: List[List[float]]) -> List[float]: + """ + Flatten a nested list into a single list. + Args: + nested_list (List[List]): A nested list to be flattened. + Returns: + List: A flattened list. + """ + return [item for sublist in nested_list for item in sublist] + + +def evaluate_eou( + *, prediction: List[dict], reference: List[dict], threshold: float, collar: float, do_sorting: bool = True +) -> EOUResult: + """ + Evaluate end of utterance predictions against reference labels. + Each item in predicition/reference is a dictionary in SegLST containing: + { + "session_id": str, + "start_time": float, # start time in seconds + "end_time": float, # end time in seconds + "words": str, # transcription of the utterance + "audio_filepath": str, # only in prediction + "eou_prob": float, # only in prediction, probability of EOU in range [0.1] + "eou_pred": bool, # only in prediction + "full_text": str, # only in prediction, which is the full transcription up to the end_time + } + + Args: + predictions (List[dict]): List of dictionaries containing predictions. + references (List[dict]): List of dictionaries containing reference labels. 
+        threshold (float): Threshold for considering a prediction as EOU.
+        collar (float): Collar time in seconds for matching predictions to references.
+        do_sorting (bool): Whether to sort the predictions and references by start time.
+    Returns:
+        EOUResult: A dataclass containing the evaluation results.
+    """
+
+    latency = []
+    early_cutoff = []
+    true_positives = 0
+    false_negatives = 0
+    false_positives = 0
+    num_utterances = len(reference)
+    num_predictions = len(prediction)
+    missing = 0
+    earlycut_ids = set()
+    predicted_eou = prediction
+    if threshold is not None and threshold > 0:
+        predicted_eou = [p for p in prediction if p["eou_prob"] > threshold]
+    elif all(["eou_pred" in p for p in prediction]):
+        # If eou_pred is available, use it
+        predicted_eou = [p for p in prediction if p["eou_pred"]]
+
+    if do_sorting:
+        predicted_eou = sorted(predicted_eou, key=lambda x: x["start_time"])
+        reference = sorted(reference, key=lambda x: x["start_time"])
+
+    p_idx = 0
+    r_idx = 0
+    for p_idx in range(len(predicted_eou)):
+        p = predicted_eou[p_idx]
+        p_start = p["start_time"]
+        p_end = p["end_time"]
+
+        while r_idx < len(reference) and reference[r_idx]["end_time"] < p_start:
+            # Current reference ends before the current predicted utterance starts, find the next reference
+            r_idx += 1
+
+        if r_idx >= len(reference):
+            # No more references to compare against
+            false_positives += 1
+            continue
+
+        r = reference[r_idx]
+        r_start = r["start_time"]
+        r_end = r["end_time"]
+
+        if np.abs(p_end - r_end) <= collar:
+            # Correctly predicted EOU
+            true_positives += 1
+            latency.append(p_end - r_end)
+            if r_idx in earlycut_ids:
+                # If this reference was already missed due to early cutoff, we do not count it again
+                earlycut_ids.remove(r_idx)
+            r_idx += 1
+        elif r_start <= p_end < r_end - collar:
+            # Early cutoff
+            # current predicted EOU is within the current reference utterance
+            false_positives += 1
+            early_cutoff.append(r_end - p_end)
+            earlycut_ids.add(r_idx)
+        elif r_end + collar < p_end:
+            # Late EOU
+            # Current predicted EOU is after the current reference ends
+            false_negatives += 1
+            latency.append(p_end - r_end)
+            if r_idx in earlycut_ids:
+                # If this reference was already missed due to early cutoff, we do not count it again
+                earlycut_ids.remove(r_idx)
+            r_idx += 1
+        else:
+            # p_end <= r_start
+            # Current predicted EOU is before the current reference starts
+            false_positives += 1
+
+    if r_idx < len(reference):
+        # There are remaining references that were not matched
+        false_negatives += len(reference) - r_idx
+        missing += len(reference) - r_idx
+
+    missing -= len(earlycut_ids)  # Remove the references that were missed due to early cutoff
+    return EOUResult(
+        latency=latency,
+        early_cutoff=early_cutoff,
+        true_positives=true_positives,
+        false_negatives=false_negatives,
+        false_positives=false_positives,
+        num_utterances=num_utterances,
+        num_predictions=num_predictions,
+        missing=missing,
+    )
+
+
+def get_SegLST_from_frame_labels(frame_labels: List[int], frame_len_in_secs: float = 0.08) -> List[dict]:
+    """
+    Convert frame labels to SegLST format.
+    Args:
+        frame_labels (List[int]): List of frame labels.
+        frame_len_in_secs (float): Length of each frame in seconds.
+    Returns:
+        List[dict]: List of dictionaries in SegLST format.
+ """ + seg_lst = [] + start_time = 0.0 + for i, label in enumerate(frame_labels): + if label > 0: + end_time = start_time + frame_len_in_secs * i + seg_lst.append({"start_time": start_time, "end_time": end_time, "eou_prob": label}) + start_time = end_time + return seg_lst + + +def cal_eou_metrics_from_frame_labels( + *, prediction: List, reference: List, threshold: float = 0.5, collar: float = 0, frame_len_in_secs: float = 0.08 +) -> EOUResult: + """ + Calculate EOU metrics from lists of predictions and references. + Args: + prediction (List): List of floats containing predicted EOU probabilities. + reference (List): List of binary floats containing reference EOU probabilities. + threshold (float): Threshold for considering a prediction as EOU. + collar (float): Collar time in seconds for matching predictions to references. + frame_len_in_secs (float): Length of each frame in seconds. + """ + + if len(prediction) != len(reference): + raise ValueError( + f"Prediction ({len(prediction)}) and reference ({len(reference)}) lists must have the same length." + ) + + pred_seg_lst = get_SegLST_from_frame_labels(prediction, frame_len_in_secs) + ref_seg_lst = get_SegLST_from_frame_labels(reference, frame_len_in_secs) + eou_metrics = evaluate_eou( + prediction=pred_seg_lst, reference=ref_seg_lst, threshold=threshold, collar=collar, do_sorting=False + ) + return eou_metrics + + +def get_percentiles(values: List[float], percentiles: List[float], tag: str = "") -> Dict[str, float]: + """ + Get the percentiles of a list of values. + Args: + values: list of values + percentiles: list of percentiles + Returns: + metrics: Dict of percentiles + """ + if len(values) == 0: + return [0.0] * len(percentiles) + results = np.percentile(values, percentiles).tolist() + metrics = {} + if tag: + tag += "_" + for i, p in enumerate(percentiles): + metrics[f'{tag}p{int(p)}'] = float(results[i]) + return metrics + + +def aggregate_eou_metrics(eou_metrics: List[EOUResult], target_percentiles: List = [50, 90, 95]) -> Dict[str, float]: + # Aggregate EOU metrics + num_eou_utterances = sum([x.num_utterances for x in eou_metrics]) + eou_latency = flatten_nested_list([x.latency for x in eou_metrics]) + eou_early_cutoff = flatten_nested_list([x.early_cutoff for x in eou_metrics]) + + eou_avg_num_early_cutoff = len(eou_early_cutoff) / num_eou_utterances if num_eou_utterances > 0 else 0.0 + if len(eou_latency) == 0: + eou_latency = [0.0] + if len(eou_early_cutoff) == 0: + eou_early_cutoff = [0.0] + + eou_missing = [x.missing for x in eou_metrics] + + metrics = {} + eou_latency_metrics = get_percentiles(eou_latency, target_percentiles, tag='latency') + eou_early_cutoff_metrics = get_percentiles(eou_early_cutoff, target_percentiles, tag='early_cutoff') + + metrics.update(eou_latency_metrics) + metrics.update(eou_early_cutoff_metrics) + + metrics['early_cutoff_rate'] = eou_avg_num_early_cutoff + metrics['miss_rate'] = sum(eou_missing) / num_eou_utterances if num_eou_utterances > 0 else 0.0 + + return metrics diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index f62b15f98dc0..bb859f21fa06 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -506,6 +506,19 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # 2.a. Noise mixing. 
if config.noise_path is not None: noise = guess_parse_cutset(config.noise_path) + noise = noise.resample(config.sample_rate) + + def mark_as_mixed_in_noise(cut): + cut.is_mixed_noise = True + return cut + + # In current lhotse implementation, if padding is applied before noise augmentation, + # and your noise manifest has dummy text field like `"text": "-"`, + # the call to MixCut.first_non_padding_cut will return the noise cut + # instead of the speech cut. We mark the noise cut with ``is_mixed_noise`` flag + # to avoid this issue, and the speech cut can be obtained by: + # `cut =[t.cut for t in cut.tracks if len(t.cut.supervisions) > 0 and not t.cut.custom.get("is_mixed_noise", False)][0]` + noise = noise.map(mark_as_mixed_in_noise) cuts = cuts.mix( cuts=noise, snr=tuple(config.noise_snr), diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 12b1d0fc60fb..519cdf4c62ca 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -1459,7 +1459,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st if isinstance(cfg.init_from_ptl_ckpt, str): # Restore checkpoint ckpt_path = cfg.pop('init_from_ptl_ckpt') - ckpt = torch.load(ckpt_path, map_location=map_location) + ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=False) # Restore checkpoint into current model self.load_state_dict(ckpt['state_dict'], strict=False) @@ -1473,7 +1473,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st for model_load_cfg in model_load_dict.values(): ckpt_path = model_load_cfg.path # Restore model - ckpt = torch.load(ckpt_path, map_location=map_location) + ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=False) include = model_load_cfg.pop('include', [""]) exclude = model_load_cfg.pop('exclude', []) diff --git a/scripts/asr_eou/add_eob_labels.py b/scripts/asr_eou/add_eob_labels.py new file mode 100644 index 000000000000..ba17b7aa98f2 --- /dev/null +++ b/scripts/asr_eou/add_eob_labels.py @@ -0,0 +1,211 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage: + +```bash +python add_eob_labels.py /path/to/manifest/dir +``` +where output will be saved in the same directory with `-eob` suffix added to the filename. 
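+
+For illustration (field values below are hypothetical), an input manifest line such as
+    {"audio_filepath": "audio1.wav", "text": "Uh-huh, yeah.", "duration": 1.2}
+is written back with an added flag, while all other fields are passed through unchanged:
+    {"audio_filepath": "audio1.wav", "text": "Uh-huh, yeah.", "duration": 1.2, "is_backchannel": true}
+Non-backchannel utterances get "is_backchannel": false.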
+""" + +import argparse +import json +from pathlib import Path +from string import punctuation + +from tqdm import tqdm + +parser = argparse.ArgumentParser(description="Add `is_backchannel` labels to manifest files.") +parser.add_argument( + "input_manifest", + type=str, + help="Path to the input manifest file to be cleaned.", +) +parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Path to the output manifest file after cleaning.", +) +parser.add_argument( + "-p", + "--pattern", + type=str, + default="*.json", + help="Pattern to match files in the input directory.", +) + + +def read_manifest(manifest_path): + manifest = [] + with open(manifest_path, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line: + manifest.append(json.loads(line)) + return manifest + + +def write_manifest(manifest_path, manifest): + with open(manifest_path, 'w') as f: + for item in manifest: + f.write(json.dumps(item) + '\n') + + +def clean_text(text): + text = text.translate(str.maketrans('', '', punctuation)).lower().strip() + valid_chars = "abcdefghijklmnopqrstuvwxyz'" + text = ''.join([c for c in text if c in valid_chars or c.isspace() or c == "'"]) + return " ".join(text.split()).strip() + + +backchannel_phrases = [ + 'absolutely', + 'ah', + 'all right', + 'alright', + 'but yeah', + 'definitely', + 'exactly', + 'go ahead', + 'good', + 'great', + 'great thanks', + 'ha ha', + 'hi', + 'i know', + 'i know right', + 'i see', + 'indeed', + 'interesting', + 'mhmm', + 'mhmm mhmm', + 'mhmm right', + 'mhmm yeah', + 'mhmm yes', + 'nice', + 'of course', + 'oh', + 'oh dear', + 'oh man', + 'oh okay', + 'oh wow', + 'oh yes', + 'ok', + 'ok thanks', + 'okay', + 'okay okay', + 'okay thanks', + 'perfect', + 'really', + 'right', + 'right exactly', + 'right right', + 'right yeah', + 'so yeah', + 'sounds good', + 'sure', + 'thank you', + 'thanks', + "that's awesome", + 'thats right', + 'thats true', + 'true', + 'uh-huh', + 'uh-huh yeah', + 'uhhuh', + 'um-humm', + 'well', + 'what', + 'wow', + 'yeah', + 'yeah i know', + 'yeah i see', + 'yeah mhmm', + 'yeah okay', + 'yeah right', + 'yeah uh-huh', + 'yeah yeah', + 'yep', + 'yes', + 'yes please', + 'yes yes', + 'you know', + "you're right", +] + +backchannel_phrases_nopc = [clean_text(phrase) for phrase in backchannel_phrases] + + +def check_if_backchannel(text): + """ + Check if the text is a backchannel phrase. 
+ """ + # Remove punctuation and convert to lowercase + text = clean_text(text) + # Check if the text is in the list of backchannel phrases + return text in backchannel_phrases_nopc + + +def add_eob_labels(manifest_path): + num_eob = 0 + manifest = read_manifest(manifest_path) + for i, item in enumerate(manifest): + text = item['text'] + # Check if the text is a backchannel phrase + is_backchannel = check_if_backchannel(text) + # Add the EOB label to the text + if is_backchannel: + item['is_backchannel'] = True + num_eob += 1 + else: + item['is_backchannel'] = False + manifest[i] = item + return manifest, num_eob + + +def main(): + args = parser.parse_args() + input_manifest = Path(args.input_manifest) + + if input_manifest.is_dir(): + manifest_list = list(input_manifest.glob(args.pattern)) + if not manifest_list: + raise ValueError(f"No files found in {input_manifest} matching pattern `{args.pattern}`") + else: + manifest_list = [input_manifest] + + if args.output is None: + output_dir = input_manifest if input_manifest.is_dir() else input_manifest.parent + else: + output_dir = Path(args.output) + output_dir.mkdir(parents=True, exist_ok=True) + + total_num_eob = 0 + print(f"Processing {len(manifest_list)} manifest files...") + for manifest_path in tqdm(manifest_list, total=len(manifest_list)): + output_file = output_dir / f"{manifest_path.stem}-eob.json" + new_manifest, num_eob = add_eob_labels(manifest_path) + total_num_eob += num_eob + write_manifest(output_file, new_manifest) + print(f"Processed {manifest_path} and saved to {output_file}. Number of EOB labels added: {num_eob}") + + print(f"Total number of EOB labels added: {total_num_eob}") + + +if __name__ == "__main__": + main() diff --git a/scripts/asr_eou/clean_manifest.py b/scripts/asr_eou/clean_manifest.py new file mode 100644 index 000000000000..54766c4f321e --- /dev/null +++ b/scripts/asr_eou/clean_manifest.py @@ -0,0 +1,646 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Example usage: + +```bash +python clean_manifest.py \ + /path/to/manifest/dir \ + -o /path/to/output/dir +``` + +""" + +import argparse +import re +import unicodedata +from pathlib import Path +from string import punctuation + +import dateutil.parser as date_parser +from num2words import num2words +from whisper_normalizer.english import EnglishTextNormalizer + +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest + +punctuations = punctuation.replace("'", "") + +text_normalizer = EnglishTextNormalizer() + +parser = argparse.ArgumentParser(description="Clean manifest file") +parser.add_argument( + "input_manifest", + type=str, + help="Path to the input manifest file to be cleaned.", +) +parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Path to the output manifest file after cleaning.", +) +parser.add_argument( + "-lower", + "--lowercase", + type=bool, + default=False, + help="Whether to convert the text to lowercase.", +) +parser.add_argument( + "-drop", + "--remove_punc", + type=bool, + default=False, + help="Whether to remove punctuation from the text.", +) +parser.add_argument( + "--normalize", + type=bool, + default=False, + help="Whether to normalize the text using Whisper text normalizer.", +) +parser.add_argument( + "-n2w", + "--replace_numbers", + type=bool, + default=True, + help="Whether to replace numbers with words.", +) +parser.add_argument( + "-p", + "--pattern", + type=str, + default="**/*.json", + help="Pattern to match files in the input directory.", +) +parser.add_argument( + "-t", + "--text_field", + type=str, + default="text", + help="Field in the manifest to clean. Default is 'text'.", +) +parser.add_argument( + "--auto_pc", + action="store_true", + help="If set, will add auto capitalization and punctuation at the end of the text.", +) +parser.add_argument( + "--format", + default="asr", + choices=["asr", "conv"], + help="Format of the manifest. 
Default is 'asr'.", +) +parser.add_argument( + "--keep_name", + action="store_true", + help="If set, will keep the original name of the manifest file.", +) + +# Spoken representations + +MONTHS = [ + "", + "January", + "February", + "March", + "April", + "May", + "June", + "July", + "August", + "September", + "October", + "November", + "December", +] + +ORDINALS = { + 1: "first", + 2: "second", + 3: "third", + 4: "fourth", + 5: "fifth", + 6: "sixth", + 7: "seventh", + 8: "eighth", + 9: "ninth", + 10: "tenth", + 11: "eleventh", + 12: "twelfth", + 13: "thirteenth", + 14: "fourteenth", + 15: "fifteenth", + 16: "sixteenth", + 17: "seventeenth", + 18: "eighteenth", + 19: "nineteenth", + 20: "twentieth", + 21: "twenty first", + 22: "twenty second", + 23: "twenty third", + 24: "twenty fourth", + 25: "twenty fifth", + 26: "twenty sixth", + 27: "twenty seventh", + 28: "twenty eighth", + 29: "twenty ninth", + 30: "thirtieth", + 31: "thirty first", +} + + +def speak_year(year: int) -> str: + if 2000 <= year <= 2099: + return f"twenty {speak_number(year % 100)}" + elif 1900 <= year <= 1999: + return f"nineteen {speak_number(year % 100)}" + else: + return str(year) + + +def speak_number(n: int) -> str: + num_words = { + 0: "zero", + 1: "one", + 2: "two", + 3: "three", + 4: "four", + 5: "five", + 6: "six", + 7: "seven", + 8: "eight", + 9: "nine", + 10: "ten", + 11: "eleven", + 12: "twelve", + 13: "thirteen", + 14: "fourteen", + 15: "fifteen", + 16: "sixteen", + 17: "seventeen", + 18: "eighteen", + 19: "nineteen", + 20: "twenty", + 30: "thirty", + 40: "forty", + 50: "fifty", + 60: "sixty", + 70: "seventy", + 80: "eighty", + 90: "ninety", + } + if n <= 20: + return num_words[n] + elif n < 100: + tens, ones = divmod(n, 10) + return f"{num_words[tens * 10]} {num_words[ones]}" if ones else num_words[tens * 10] + else: + return str(n) + + +def parse_with_auto_dayfirst(date_str: str): + try: + # Try both ways + parsed_us = date_parser.parse(date_str, dayfirst=False) + parsed_eu = date_parser.parse(date_str, dayfirst=True) + + # If one of the parses clearly makes more sense, return it + if parsed_us.month > 12: + return parsed_eu + if parsed_eu.month > 12: + return parsed_us + + # If day is greater than 12, it's probably day-first + if parsed_us.day > 12 and parsed_eu.day <= 12: + return parsed_eu + elif parsed_eu.day > 12 and parsed_us.day <= 12: + return parsed_us + + # Default fallback (assumes US style) + return parsed_us + except Exception: + return None + + +def date_to_spoken_string(date_str: str) -> str: + parsed = parse_with_auto_dayfirst(date_str) + if not parsed: + return None + + month = MONTHS[parsed.month] + day = ORDINALS[parsed.day] + spoken = f"{month} {day} {speak_year(parsed.year)}" + + return spoken + + +def replace_dates_in_text(text: str) -> str: + # Regex pattern to match common date formats like: + # 5/22, 05/22/2025, 22/05/2025, 2025-05-22 + date_pattern = r'\b(?:\d{1,4}[-/])?\d{1,2}[-/]\d{1,4}\b' + + def replace_match(match): + date_str = match.group(0) + spoken = date_to_spoken_string(date_str) + return spoken if spoken else date_str + + return re.sub(date_pattern, replace_match, text) + + +def convert_to_spoken(text: str) -> str: + + text = replace_dates_in_text(text) # Convert dates to spoken form + + # Mapping of metric units to spoken forms + unit_map = { + "kg": "kilograms", + "g": "grams", + "mg": "milligrams", + "l": "liters", + "ml": "milliliters", + "cm": "centimeters", + "mm": "millimeters", + "m": "meters", + "km": "kilometers", + "°c": "degrees celsius", + "°f": "degrees 
fahrenheit", + "oz": "ounces", + "lb": "pounds", + "lbs": "pounds", + } + + # Replace metric units like "12kg" or "5 ml" + def replace_metric(match): + number = match.group(1) + unit = match.group(2).lower() + spoken_unit = unit_map.get(unit, unit) + return f"{number} {spoken_unit}" + + # Replace time like "5am" or "6PM" + def replace_ampm(match): + hour = match.group(1) + meridiem = match.group(2).lower() + return f"{hour} {'a m' if meridiem == 'am' else 'p m'}" + + # Replace time like "1:30pm" + def replace_colon_time(match): + hour = match.group(1) + minute = match.group(2) + meridiem = match.group(3).lower() + return f"{hour} {minute} {'a m' if meridiem == 'am' else 'p m'}" + + # Convert feet and inches like 5'11" to "5 feet 11 inches" + def replace_feet_inches(match): + feet = match.group(1) + inches = match.group(2) + return f"{feet} feet {inches} inches" + + # Convert just feet (e.g., 6') to "6 feet" + def replace_feet_only(match): + feet = match.group(1) + return f"{feet} feet" + + # Convert just inches (e.g., 10") to "10 inches" + def replace_inches_only(match): + inches = match.group(1) + return f"{inches} inches" + + # Apply replacements + # First: time with colon (e.g., 1:30pm) + text = re.sub(r'\b(\d{1,2}):(\d{2})(am|pm)\b', replace_colon_time, text, flags=re.IGNORECASE) + + # Then: basic am/pm (e.g., 5am) + text = re.sub(r'\b(\d{1,2})(am|pm)\b', replace_ampm, text, flags=re.IGNORECASE) + + # Then: replace 1st, 2nd, 3rd, etc + text = text.replace("1st", "first") + text = text.replace("2nd", "second") + text = text.replace("3rd", "third") + text = text.replace("@", " at ") + + # Finally: metric units + text = re.sub( + r'\b(\d+(?:\.\d+)?)\s?(kg|g|mg|l|ml|cm|mm|m|km|°c|°f|oz|lbs?|LB|LBS?)\b', + replace_metric, + text, + flags=re.IGNORECASE, + ) + text = re.sub(r'\b(\d+)\'(\d+)"', replace_feet_inches, text) # e.g., 5'11" + text = re.sub(r'\b(\d+)\'', replace_feet_only, text) # e.g., 6' + text = re.sub(r'(\d+)"', replace_inches_only, text) # e.g., 10" + + return text + + +def replace_numbers_with_words(text): + def convert_number(match): + num_str = match.group() + original = num_str + + # Remove dollar sign + is_dollar = False + if num_str.startswith('$'): + is_dollar = True + num_str = num_str[1:] + elif num_str.endswith('$'): + is_dollar = True + num_str = num_str[:-1] + + # Remove commas + num_str = num_str.replace(',', '') + + try: + if '.' 
in num_str: + # Convert decimal number + integer_part, decimal_part = num_str.split('.') + words = num2words(int(integer_part)) + ' point ' + ' '.join(num2words(int(d)) for d in decimal_part) + else: + words = num2words(int(num_str)) + if is_dollar: + words += ' dollars' + return words + " " + except Exception: + return original # Return original if conversion fails + + # Pattern matches: $3,000 or 3,000.45 or 1234 + pattern = re.compile(r'\$?\d{1,3}(?:,\d{3})*(?:\.\d+)?|\$?\d+(?:\.\d+)?') + result = pattern.sub(convert_number, text) + result = result.replace("$", " dollars ") # Handle dollar sign separately + + def merge_th(text: str) -> str: + # merge th with the preceding digit + candidates = ["four th ", "five th ", "six th ", "seven th ", "eight th ", "nine th "] + for key in candidates: + if key in text: + if "five" in key: + target = "fifth " + else: + target = f"{key.split(' ')[0]}th " + text = text.replace(key, target) + elif text.endswith(key.strip()): + if "five" in key: + target = "fifth" + else: + target = f"{key.split(' ')[0]}th" + text = text.replace(key.strip(), target) + return text + + result = merge_th(result) + result = " ".join(result.split()) # Remove extra spaces + return result + + +def unicode_to_ascii(text: str) -> str: + """ + Converts text with accented or special Latin characters (e.g., ó, ñ, ū, ō) + into their closest ASCII equivalents. + """ + # Normalize the string to NFKD to separate base characters from diacritics + normalized = unicodedata.normalize('NFKD', text) + + # Encode to ASCII bytes, ignoring characters that can't be converted + ascii_bytes = normalized.encode('ascii', 'ignore') + + # Decode back to string + ascii_text = ascii_bytes.decode('ascii') + + return ascii_text + + +def drop_punctuations(text: str) -> str: + """ + Clean the text by removing invalid characters and converting to lowercase. + + :param text: Input text. + :return: Cleaned text. 
+ """ + valid_chars = "abcdefghijklmnopqrstuvwxyz'" + text = text.lower() + text = unicode_to_ascii(text) + text = text.replace(":", " ") + text = text.replace("-", " ") + text = ''.join([c for c in text if c in valid_chars or c.isspace()]) + text = ' '.join(text.split()).strip() + return text + + +def clean_label(_str: str) -> str: + """ + Remove unauthorized characters in a string, lower it and remove unneeded spaces + """ + # replace_with_space = [char for char in '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→'] + replace_with_blank = [char for char in '`¨´‘’“”`ʻ‘’“"‘”'] + replace_with_apos = [char for char in '‘’ʻ‘’‘'] + ["\u2019"] + _str = _str.strip() + for i in replace_with_blank: + _str = _str.replace(i, "") + for i in replace_with_apos: + _str = _str.replace(i, "'") + + text = _str + text = text.replace("\u2103", "celsius") + text = text.replace("\u2109", "fahrenheit") + text = text.replace("\u00b0", "degrees") + text = text.replace("\u2019", "'") + text = text.replace("\\", ".") + text = text.replace("\n", " ") + text = text.replace("\r", " ") + text = text.replace("\t", " ") + + ret = " ".join(text.split()) + return ret + + +def ends_with_punctuation(s: str) -> bool: + # Strip trailing whitespace + s = s.rstrip() + + # consider this set to be punctuation that's acceptable to end a sentence with + puncturation_chars = [",", ".", ":", ";", "?", "!", "-", "—", "–", "…"] + + # If string is empty after stripping, return False + if not s: + return False + + # Get the last character + last_char = s[-1] + + # Return True if the last character is punctuation, otherwise False + return last_char in puncturation_chars + + +def add_period_if_needed(text: str) -> str: + """ + Add a period at the end of the text if it does not already end with one. + """ + if not ends_with_punctuation(text): + text += "." + return text.strip() + + +def capitalize_self_i(text): + # Replace standalone lowercase "i" with "I" + # Handles "i", "i.", "i?", "i'll", "i'm", etc. + return re.sub(r'\b(i)(?=[\s.,!?;:\'\"-]|$)', r'I', text) + + +def add_space_after_punctuation(text): + # Add a space after punctuation if it's not already followed by one or by the end of the string + return re.sub(r'([,\.?;:])(?=\S)', r'\1 ', text) + + +def add_auto_capitalization(text): + if text.lower() != text: + # If the text is not all lowercase, we assume it has some capitalization + return text + + # Remove space before punctuation (.,!?;:) + text = re.sub(r'\s+([.,!?;:])', r'\1', text) + + # Capitalize the first letter of each sentence + def capitalize_sentences(match): + return match.group(1) + match.group(2).upper() + + # Ensure first character is capitalized + text = text.strip() + if text: + text = text[0].upper() + text[1:] + + text = capitalize_self_i(text) + text = add_space_after_punctuation(text) + # Capitalize after sentence-ending punctuation followed by space(s) + text = re.sub(r'([.!?]\s+)([a-z])', capitalize_sentences, text) + return text + + +def unicode_to_ascii(text: str) -> str: + """ + Converts text with accented or special Latin characters (e.g., ó, ñ, ū, ō) + into their closest ASCII equivalents. 
+ """ + # Normalize the string to NFKD to separate base characters from diacritics + normalized = unicodedata.normalize('NFKD', text) + + # Encode to ASCII bytes, ignoring characters that can't be converted + ascii_bytes = normalized.encode('ascii', 'ignore') + + # Decode back to string + ascii_text = ascii_bytes.decode('ascii') + + return ascii_text + + +def clean_text(text: str, args) -> str: + """ + Clean the text based on the provided arguments. + """ + text = unicode_to_ascii(text) + if args.normalize: + text = text_normalizer(text) + if args.replace_numbers: + text = convert_to_spoken(text) + text = replace_numbers_with_words(text) + if args.lowercase: + text = text.lower() + if args.remove_punc: + text = text.replace("-", " ") + text = text.replace("_", " ") + text = text.translate(str.maketrans("", "", punctuations)) + text = drop_punctuations(text) + if args.auto_pc: + text = add_auto_capitalization(text) + return clean_label(text) + + +def clean_asr_manifest(manifest, text_field, args): + for i, item in enumerate(manifest): + text = str(item[text_field]) + manifest[i][f"origin_{text_field}"] = text + manifest[i][text_field] = clean_text(text, args) + return manifest + + +def clean_conv_manifest(manifest, text_field, args): + new_manifest = [] + for i, item in enumerate(manifest): + conversations = [] + for turn in item["conversations"]: + conversations.append( + { + "role": turn["role"], + "value": clean_text(turn["value"], args), + "type": turn.get("type", "text"), + } + ) + item["conversations"] = conversations + new_manifest.append(item) + return manifest + + +def main(args): + text_field = args.text_field + manifest_files = Path(args.input_manifest) + if manifest_files.is_dir(): + manifest_files = list(manifest_files.glob(args.pattern)) + elif manifest_files.is_file(): + manifest_files = [manifest_files] + else: + raise ValueError(f"Invalid input manifest path: {args.input_manifest}") + + for manifest_file in manifest_files: + print(f"Processing manifest file: {manifest_file}") + postfix = "-cleaned" + postfix += "_norm" if args.normalize else "" + postfix += "_n2w" if args.replace_numbers else "" + if args.lowercase and args.remove_punc: + postfix += "_noPC" + else: + postfix += "_lc" if args.lowercase else "" + postfix += "_np" if args.remove_punc else "" + postfix += "_aPC" if args.auto_pc else "" + + output_manifest = manifest_file.with_name(f"{manifest_file.stem}{postfix}{manifest_file.suffix}") + + if args.output: + if args.output.endswith(".json"): + if len(manifest_files) > 1: + raise ValueError("Output path must be a directory when processing multiple manifest files.") + output_manifest = Path(args.output) + else: + output_dir = Path(args.output) + output_dir.mkdir(parents=True, exist_ok=True) + if args.keep_name: + output_manifest = output_dir / manifest_file.name + else: + output_manifest = output_dir / output_manifest.name + + manifest = read_manifest(str(manifest_file)) + + if args.format == "asr": + manifest = clean_asr_manifest(manifest, text_field, args) + elif args.format == "conv": + manifest = clean_conv_manifest(manifest, text_field, args) + else: + raise ValueError(f"Unsupported manifest format: {args.format}") + + write_manifest(str(output_manifest), manifest) + print(f"Cleaned manifest saved to {output_manifest}") + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/scripts/asr_eou/conf/data.yaml b/scripts/asr_eou/conf/data.yaml new file mode 100644 index 000000000000..c27bf129fa1e --- /dev/null +++ 
b/scripts/asr_eou/conf/data.yaml @@ -0,0 +1,46 @@ + +output_dir: ??? + +data: + pattern: "*.json" + manifest_filepath: ??? + tarred_audio_filepaths: null + sample_rate: 16000 + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + batch_duration: 300 # you may disable batch_duration by setting it to `null` + batch_size: null + shuffle: false + seed: 42 + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + + random_padding: + prob: 1.0 + min_pad_duration: 0.0 # minimum duration of pre/post padding in seconds + max_pad_duration: 3.0 # maximum duration of pre/post padding in seconds + max_total_duration: 40.0 # maximum total duration of the padded audio in seconds + pad_distribution: 'constant' # distribution of padding duration, 'uniform' or 'normal' or 'constant' + pre_pad_duration: 0.2 + post_pad_duration: 3.0 + + augmentor: + white_noise: + prob: 0.0 + min_level: -90 + max_level: -40 + gain: + prob: 0.0 + min_gain_dbfs: -10.0 + max_gain_dbfs: 10.0 + noise: + prob: 1.0 + manifest_path: ??? + min_snr_db: 0 + max_snr_db: 20 + max_gain_db: 300.0 \ No newline at end of file diff --git a/scripts/asr_eou/eval_eou_metrics.py b/scripts/asr_eou/eval_eou_metrics.py new file mode 100644 index 000000000000..9c96afff98b7 --- /dev/null +++ b/scripts/asr_eou/eval_eou_metrics.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
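The `???` entries in `conf/data.yaml` above are mandatory values that must be supplied at run time. As a minimal sketch (hypothetical paths; the override keys follow the nesting shown in the config, e.g. the noise manifest sits under `data.augmentor.noise`), they can be passed as Hydra overrides to `generate_noisy_eval_data.py`, the data-generation script added later in this diff that consumes this config:

```bash
python generate_noisy_eval_data.py \
    --config-path conf/ \
    --config-name data \
    output_dir=/path/to/output \
    data.manifest_filepath=/path/to/manifest.json \
    data.augmentor.noise.manifest_path=/path/to/noise_manifest.json
```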
+ +""" +Example usage: + +The PREDICTION_ROOT and REFERENCE_ROOT directories should have the following structure: + +: +->dataset1/ + eou/ + -> sample1.json + -> sample2.json +->dataset2/ + eou/ + -> sample1.json + -> sample2.json + +: +->dataset1/ + -> sample1.json + -> sample2.json +->dataset2/ + -> sample1.json + -> sample2.json + + +each sample.json should contain a list of dictionaries with the following fields: +{ + "session_id": str, + "start_time": float, # start time in seconds + "end_time": float, # end time in seconds + "words": str, # transcription of the utterance + "audio_filepath": str, # only in prediction + "eou_prob": float, # only in prediction, probability of EOU in range [0.1] + "eou_pred": bool, # only in prediction + "full_text": str, # only in prediction, which is the full transcription up to the end_time +} + +```bash +python eval_eou_with_niva.py \ + --prediction $PREDICTION_ROOT \ + --reference $REFERENCE_ROOT \ + --multiple +``` +""" + + +import argparse +import json +from pathlib import Path +from typing import List + +from nemo.collections.asr.parts.utils.eou_utils import EOUResult, aggregate_eou_metrics, evaluate_eou + +parser = argparse.ArgumentParser(description="Evaluate end of utterance predictions against reference labels.") +parser.add_argument( + "-p", + "--prediction", + type=str, + required=True, + help="Path to the directory containing the predictions.", +) +parser.add_argument( + "-r", + "--reference", + type=str, + required=True, + help="Path to the directory containing the groundtruth.", +) +parser.add_argument( + "--eob", + action="store_true", + help="Whether to evaluate end of backchannel predictions.", +) +parser.add_argument( + "--ignore_eob", + action="store_true", + help="Whether to ignore end of backchannel predictions.", +) +parser.add_argument( + "--multiple", + action="store_true", + help="Whether to evaluate multiple datasets.", +) + + +def load_segLST(directory: str, use_eob: bool = False, ignore_eob: bool = False) -> dict: + json_files = list(Path(directory).glob("*.json")) + segLST = {} + for json_file in json_files: + key = json_file.stem + with open(json_file, 'r') as f: + data = json.load(f) + assert isinstance(data, list), f"Data in {json_file} is not a list." 
+ if not ignore_eob: + # get the data with the correct eob label + data = [x for x in data if (x.get("is_backchannel", False) == use_eob)] + segLST[key] = data + return segLST + + +def evaluate_eou_predictions( + prediction_dir: str, reference_dir: str, use_eob: bool = False, ignore_eob: bool = False +) -> List[EOUResult]: + prediction_dir = Path(prediction_dir) / "eou" + prediction_segLST = load_segLST(prediction_dir, use_eob, ignore_eob) + reference_segLST = load_segLST(reference_dir, use_eob, ignore_eob) + + eou_metrics = [] + for key, reference in reference_segLST.items(): + if key not in prediction_segLST: + raise ValueError(f"Key {key} in reference not found in predictions.") + prediction = prediction_segLST[key] + eou_result = evaluate_eou( + prediction=prediction, reference=reference, threshold=None, collar=0.0, do_sorting=True + ) + eou_metrics.append(eou_result) + + results = aggregate_eou_metrics(eou_metrics) + + # add prefix to the keys of the results + prefix = Path(reference_dir).stem + prefix += "_eob" if use_eob else "_eou" + results = {f"{prefix}_{k}": v for k, v in results.items()} + + return results + + +if __name__ == "__main__": + args = parser.parse_args() + + prediction_dir = Path(args.prediction) + reference_dir = Path(args.reference) + + if not prediction_dir.is_dir(): + raise ValueError(f"Prediction directory {prediction_dir} does not exist or is not a directory.") + if not reference_dir.is_dir(): + raise ValueError(f"Reference directory {reference_dir} does not exist or is not a directory.") + + if args.multiple: + # get all subdirectories in the prediction and reference directories + prediction_dirs = sorted([x for x in prediction_dir.glob("*/") if x.is_dir()]) + reference_dirs = sorted([x for x in reference_dir.glob("*/") if x.is_dir()]) + if len(prediction_dirs) != len(reference_dirs): + raise ValueError( + f"Number of prediction directories {len(prediction_dirs)} must match number of reference directories {len(reference_dirs)}." + ) + else: + prediction_dirs = [prediction_dir] + reference_dirs = [reference_dir] + + for ref_dir, pred_dir in zip(reference_dirs, prediction_dirs): + if args.multiple and ref_dir.stem != pred_dir.stem: + raise ValueError( + f"Reference directory {ref_dir} and prediction directory {pred_dir} must have the same name." + ) + results = evaluate_eou_predictions( + prediction_dir=str(pred_dir), reference_dir=str(ref_dir), use_eob=args.eob, ignore_eob=args.ignore_eob + ) + # Print the results + print("==========================================") + print(f"Evaluation Results for: {pred_dir} against {ref_dir}") + for key, value in results.items(): + print(f"{key}: {value:.4f}") + print("==========================================") diff --git a/scripts/asr_eou/generate_noisy_eval_data.py b/scripts/asr_eou/generate_noisy_eval_data.py new file mode 100644 index 000000000000..daf3f807d4aa --- /dev/null +++ b/scripts/asr_eou/generate_noisy_eval_data.py @@ -0,0 +1,211 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script is used to generate noisy evaluation data for ASR and end of utterance detection. + +Example usage with a single manifest input: +python generate_noisy_eval_data.py \ + --config-path conf/ \ + --config-name data \ + output_dir=/path/to/output \ + data.manifest_filepath=/path/to/manifest.json \ + data.seed=42 \ + data.noise.manifest_path /path/to/noise_manifest.json + + +Example usage with multiple manifests matching a pattern: +python generate_noisy_eval_data.py \ + --config-path conf/ \ + --config-name data \ + output_dir=/path/to/output/dir \ + data.manifest_filepath=/path/to/manifest/dir/ \ + data.pattern="*.json" \ + data.seed=42 \ + data.noise.manifest_path /path/to/noise_manifest.json + +""" + +from copy import deepcopy +from pathlib import Path +from shutil import rmtree + +import librosa +import lightning.pytorch as pl +import numpy as np +import soundfile as sf +import torch +import yaml +from lhotse.cut import MixedCut +from omegaconf import ListConfig, OmegaConf, open_dict +from tqdm import tqdm + +from nemo.collections.asr.data.audio_to_eou_label_lhotse import LhotseSpeechToTextBpeEOUDataset +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.parts.preprocessing import parsers +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +@hydra_runner(config_path="conf/", config_name="data") +def main(cfg): + # Seed everything for reproducibility + seed = cfg.data.get('seed', None) + if seed is None: + seed = np.random.randint(0, 2**32 - 1) + logging.info(f'No seed provided, using random seed: {seed}') + logging.info(f'Setting random seed to {seed}') + with open_dict(cfg): + cfg.data.seed = seed + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + pl.seed_everything(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # Patch data config + with open_dict(cfg.data): + cfg.data.force_finite = True + cfg.data.force_map_dataset = True + cfg.data.shuffle = False + cfg.data.check_tokenizer = False # No need to check tokenizer in LhotseSpeechToTextBpeEOUDataset + + # Make output directory + output_dir = Path(cfg.output_dir) + if output_dir.exists() and cfg.get('overwrite', False): + logging.info(f'Removing existing output directory: {output_dir}') + rmtree(output_dir) + if not output_dir.exists(): + logging.info(f'Creating output directory: {output_dir}') + output_dir.mkdir(parents=True, exist_ok=True) + + # Dump the config to the output directory + config = OmegaConf.to_container(cfg, resolve=True) + with open(output_dir / 'config.yaml', 'w') as f: + yaml.dump(config, f) + logging.info(f'Config dumped to {output_dir / "config.yaml"}') + + if isinstance(cfg.data.manifest_filepath, (list, ListConfig)): + manifest_list = [Path(x) for x in cfg.data.manifest_filepath] + else: + input_manifest_file = Path(cfg.data.manifest_filepath) + if input_manifest_file.is_dir(): + pattern = cfg.data.get('pattern', '*.json') + manifest_list = list(input_manifest_file.glob(pattern)) + if not manifest_list: + raise ValueError(f"No files found in {input_manifest_file} matching pattern `{pattern}`") + else: + manifest_list = [Path(x) for x in 
str(input_manifest_file).split(",")] + + logging.info(f'Found {len(manifest_list)} manifest files to process...') + + for i, manifest_file in enumerate(manifest_list): + logging.info(f'[{i+1}/{len(manifest_list)}] Processing {manifest_file}...') + data_cfg = deepcopy(cfg.data) + data_cfg.manifest_filepath = str(manifest_file) + process_manifest(data_cfg, output_dir) + + +def process_manifest(data_cfg, output_dir): + # Load the input manifest + input_manifest = read_manifest(data_cfg.manifest_filepath) + logging.info(f'Found {len(input_manifest)} items in input manifest: {data_cfg.manifest_filepath}') + manifest_parent_dir = Path(data_cfg.manifest_filepath).parent + if Path(input_manifest[0]["audio_filepath"]).is_absolute(): + output_audio_dir = output_dir / 'wav' + flatten_audio_path = True + else: + output_audio_dir = output_dir + flatten_audio_path = False + + if "random_padding" in data_cfg and data_cfg.random_padding.pad_distribution == "constant": + is_constant_padding = True + pre_pad_dur = data_cfg.random_padding.pre_pad_duration + else: + is_constant_padding = False + pre_pad_dur = None + + # Load the dataset + tokenizer = parsers.make_parser() # dummy tokenizer + dataset = LhotseSpeechToTextBpeEOUDataset(cfg=data_cfg, tokenizer=tokenizer, return_cuts=True) + + dataloader = get_lhotse_dataloader_from_config( + config=data_cfg, + global_rank=0, + world_size=1, + dataset=dataset, + tokenizer=tokenizer, + ) + + # Generate noisy evaluation data + manifest = [] + for i, batch in enumerate(tqdm(dataloader, desc="Generating noisy evaluation data")): + audio_batch, audio_len_batch, cuts_batch = batch + audio_batch = audio_batch.cpu().numpy() + audio_len_batch = audio_len_batch.cpu().numpy() + + for j in range(len(cuts_batch)): + cut = cuts_batch[j] + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + + manifest_item = {} + for k, v in cut.custom.items(): + if k == "dataloading_info": + continue + manifest_item[k] = v + audio = audio_batch[j][: audio_len_batch[j]] + audio_file = cut.recording.sources[0].source + + if flatten_audio_path: + output_audio_file = output_audio_dir / str(audio_file).replace('/', '_')[:255] # type: Path + else: + output_audio_file = output_audio_dir / Path(audio_file).relative_to(manifest_parent_dir) # type: Path + + output_audio_file.parent.mkdir(parents=True, exist_ok=True) + sf.write(output_audio_file, audio, dataset.sample_rate) + + manifest_item["audio_filepath"] = str(output_audio_file.relative_to(output_audio_dir)) + manifest_item["offset"] = 0 + manifest_item["duration"] = audio.shape[0] / dataset.sample_rate + + if is_constant_padding: + # Adjust the sou_time and eou_time for constant padding + if 'sou_time' in manifest_item and 'eou_time' in manifest_item: + if not isinstance(manifest_item['sou_time'], list): + manifest_item['sou_time'] = manifest_item['sou_time'] + pre_pad_dur + manifest_item['eou_time'] = manifest_item['eou_time'] + pre_pad_dur + else: + manifest_item['sou_time'] = [x + pre_pad_dur for x in manifest_item['sou_time']] + manifest_item['eou_time'] = [x + pre_pad_dur for x in manifest_item['eou_time']] + else: + # add sou_time and eou_time to the manifest item + manifest_item['sou_time'] = pre_pad_dur + manifest_item['eou_time'] = pre_pad_dur + librosa.get_duration(filename=audio_file) + + manifest.append(manifest_item) + + # Write the output manifest + output_manifest_file = output_dir / Path(data_cfg.manifest_filepath).name + write_manifest(output_manifest_file, manifest) + logging.info(f'Output manifest written to 
{output_manifest_file}') + + +if __name__ == "__main__": + main() diff --git a/scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py b/scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py new file mode 100644 index 000000000000..cc4f86d0a71b --- /dev/null +++ b/scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py @@ -0,0 +1,187 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" +import logging +import sys +import tempfile +from argparse import ArgumentParser +from pathlib import Path + +import sentencepiece as spm + +from nemo.collections.asr.data.audio_to_eou_label_lhotse import EOB_STRING, EOU_STRING +from nemo.core.connectors.save_restore_connector import SaveRestoreConnector + +try: + import sentencepiece_model_pb2 as spt +except (ImportError, ModuleNotFoundError): + raise Exception("Ensure that sentencepiece_model_pb2.py has been generated from the protoc compiler") + + +SPECIAL_TOKENS = [EOU_STRING, EOB_STRING] + +"""Utility to add special tokens to existing sentencepiece models. + +Generate sentencepiece_model_pb2.py in the directory of this script before running +To generate run `protoc --python_out=/scripts/asr_end_of_utterance/tokenizers sentencepiece_model.proto` +inside the src folder in sentencepiece repo +Refer: https://github.com/google/sentencepiece/issues/121 + +Usage: +python add_special_tokens_to_sentencepiece.py \ + --input_file your_model.nemo \ + --output_dir /path/to/new/tokenizer_dir/ +""" + + +parser = ArgumentParser(description="Add special tokens to sentencepiece model") +parser.add_argument( + "-i", + "--input_file", + type=str, + required=True, + help="Path to nemo model file, or sentencepiece model file", +) +parser.add_argument( + "-o", + "--output_dir", + type=str, + required=True, + help="Path to output directory for new tokenizer", +) +parser.add_argument( + "--tokens", + type=str, + nargs='+', + help="Special tokens to add to tokenizer", + default=SPECIAL_TOKENS, +) +parser.add_argument( + "--extract_only", + action="store_true", + help="Extract tokenizer without adding special tokens", +) + + +def extract_nemo_tokenizer(nemo_filepath, output_dir): + SaveRestoreConnector._unpack_nemo_file(path2file=nemo_filepath, out_folder=output_dir) + tokenizer = None + for file in Path(output_dir).glob("**/*"): + if file.is_file() and file.name.endswith("tokenizer.model"): + tokenizer = file + break + if tokenizer is None: + raise ValueError(f"Tokenizer not found in {output_dir}: {os.listdir(output_dir)}") + return str(tokenizer.absolute()) + + +def edit_spt_model(input_file, output_dir, tokens, is_userdefined, extract_only=False): + if extract_only: + logging.info("Extracting tokenizer only, no special tokens will be added.") + + output_dir = Path(output_dir) + + if output_dir.exists(): + logging.warning(f"Output directory {output_dir} already exists. 
Overwriting it.") + + output_dir.mkdir(parents=True, exist_ok=True) + + output_file = str(output_dir / "tokenizer.model") + + token_type = 3 + if is_userdefined: + token_type = 4 + + model = spt.ModelProto() + with open(input_file, 'rb') as f: + model.ParseFromString(f.read()) + + if not extract_only: + for token in tokens: + piece = model.SentencePiece(piece=token, score=0.0, type=token_type) + if piece in model.pieces: + logging.error(f"Special Token '{token}' already exists in the input model!") + sys.exit(1) + model.pieces.append(piece) + + sp = spm.SentencePieceProcessor() + sp.LoadFromSerializedProto(model.SerializeToString()) + + if not extract_only: + try: + for token in tokens: + id = sp.piece_to_id(token) + logging.info(f"Created token '{token}' at ID {id}") + logging.info(f"New tokenizer vocab size: {sp.get_piece_size()}") + except Exception: + logging.error( + "Could not appropriately configure new tokenizer. Verify if the special tokens already exist." + ) + sys.exit(1) + + with open(output_file, 'wb') as outf: + outf.write(model.SerializeToString()) + logging.info(f"Created new tokenizer at: {output_file}") + + # Write the vocab to file + vocab_file = str(output_dir / "tokenizer.vocab") + with open(vocab_file, "w", encoding="utf-8") as f: + for i in range(sp.get_piece_size()): + piece = sp.id_to_piece(i) + score = sp.get_score(i) # Optional: only available if using newer SentencePiece versions + f.write(f"{piece}\t{score}\n") # Format follows the original vocab format + logging.info(f"Created new tokenizer vocab at: {vocab_file}") + + special_tokens = ["", "", "", ""] + special_tokens.extend(tokens) + vocab_txt_file = str(output_dir / "vocab.txt") + with open(vocab_txt_file, "w", encoding="utf-8") as f: + for i in range(sp.get_piece_size()): + piece = sp.id_to_piece(i) + if piece in special_tokens: + # skip special tokens + continue + token = piece[1:] if piece.startswith("▁") else f"##{piece}" + if len(token) > 0: + f.write(f"{token}\n") # Format follows the original vocab format + logging.info(f"Created new tokenizer vocab at: {vocab_txt_file}") + + +def inject_special_tokens(input_file, output_dir, tokens, is_userdefined=True, extract_only=False): + """ + NOTE: is_userdefined should be set to True in order for ASR model to work + with the new special tokens properly. + """ + if not os.path.exists(input_file): + raise ValueError(f"Input file {input_file} does not exist") + + with tempfile.TemporaryDirectory() as temp_dir: + # Check if input file is a Nemo file + if input_file.endswith(".nemo"): + input_file = extract_nemo_tokenizer(input_file, temp_dir) + logging.info(f"Extracted tokenizer from Nemo file: {input_file}") + else: + input_file = os.path.abspath(input_file) + logging.info(f"Using input file: {input_file}") + + edit_spt_model(input_file, output_dir, tokens, is_userdefined, extract_only=extract_only) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + args = parser.parse_args() + inject_special_tokens(args.input_file, args.output_dir, args.tokens, extract_only=args.extract_only) diff --git a/scripts/asr_eou/tokenizers/sentencepiece_model_pb2.py b/scripts/asr_eou/tokenizers/sentencepiece_model_pb2.py new file mode 100644 index 000000000000..cb97411349aa --- /dev/null +++ b/scripts/asr_eou/tokenizers/sentencepiece_model_pb2.py @@ -0,0 +1,1442 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: sentencepiece_model.proto + +import sys + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='sentencepiece_model.proto', + package='sentencepiece', + syntax='proto2', + serialized_options=_b('H\003'), + serialized_pb=_b( + '\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\xa4\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! 
\x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\x12\"\n\x18seed_sentencepieces_file\x18\x36 \x01(\t:\x00\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03' + ), +) + + +_TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor( + name='ModelType', + full_name='sentencepiece.TrainerSpec.ModelType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor(name='UNIGRAM', index=0, number=1, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='BPE', index=1, number=2, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='WORD', index=2, number=3, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='CHAR', index=3, number=4, serialized_options=None, type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=1553, + serialized_end=1606, +) +_sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE) + +_MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor( + name='Type', + full_name='sentencepiece.ModelProto.SentencePiece.Type', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor(name='NORMAL', index=0, number=1, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='UNKNOWN', index=1, number=2, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='CONTROL', index=2, number=3, 
serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='USER_DEFINED', index=3, number=4, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='BYTE', index=4, number=6, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='UNUSED', index=5, number=5, serialized_options=None, type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=2359, + serialized_end=2443, +) +_sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE) + + +_TRAINERSPEC = _descriptor.Descriptor( + name='TrainerSpec', + full_name='sentencepiece.TrainerSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='input', + full_name='sentencepiece.TrainerSpec.input', + index=0, + number=1, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='input_format', + full_name='sentencepiece.TrainerSpec.input_format', + index=1, + number=7, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='model_prefix', + full_name='sentencepiece.TrainerSpec.model_prefix', + index=2, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='model_type', + full_name='sentencepiece.TrainerSpec.model_type', + index=3, + number=3, + type=14, + cpp_type=8, + label=1, + has_default_value=True, + default_value=1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='vocab_size', + full_name='sentencepiece.TrainerSpec.vocab_size', + index=4, + number=4, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=8000, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='accept_language', + full_name='sentencepiece.TrainerSpec.accept_language', + index=5, + number=5, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='self_test_sample_size', + full_name='sentencepiece.TrainerSpec.self_test_sample_size', + index=6, + number=6, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='enable_differential_privacy', + full_name='sentencepiece.TrainerSpec.enable_differential_privacy', + index=7, + number=50, + type=8, + cpp_type=7, + label=1, + 
has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='differential_privacy_noise_level', + full_name='sentencepiece.TrainerSpec.differential_privacy_noise_level', + index=8, + number=51, + type=2, + cpp_type=6, + label=1, + has_default_value=True, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='differential_privacy_clipping_threshold', + full_name='sentencepiece.TrainerSpec.differential_privacy_clipping_threshold', + index=9, + number=52, + type=4, + cpp_type=4, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='character_coverage', + full_name='sentencepiece.TrainerSpec.character_coverage', + index=10, + number=10, + type=2, + cpp_type=6, + label=1, + has_default_value=True, + default_value=float(0.9995), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='input_sentence_size', + full_name='sentencepiece.TrainerSpec.input_sentence_size', + index=11, + number=11, + type=4, + cpp_type=4, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='shuffle_input_sentence', + full_name='sentencepiece.TrainerSpec.shuffle_input_sentence', + index=12, + number=19, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='mining_sentence_size', + full_name='sentencepiece.TrainerSpec.mining_sentence_size', + index=13, + number=12, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=_b('\030\001'), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='training_sentence_size', + full_name='sentencepiece.TrainerSpec.training_sentence_size', + index=14, + number=13, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=_b('\030\001'), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='seed_sentencepiece_size', + full_name='sentencepiece.TrainerSpec.seed_sentencepiece_size', + index=15, + number=14, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=1000000, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='shrinking_factor', + full_name='sentencepiece.TrainerSpec.shrinking_factor', + 
index=16, + number=15, + type=2, + cpp_type=6, + label=1, + has_default_value=True, + default_value=float(0.75), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='max_sentence_length', + full_name='sentencepiece.TrainerSpec.max_sentence_length', + index=17, + number=18, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=4192, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='num_threads', + full_name='sentencepiece.TrainerSpec.num_threads', + index=18, + number=16, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=16, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='num_sub_iterations', + full_name='sentencepiece.TrainerSpec.num_sub_iterations', + index=19, + number=17, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=2, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='max_sentencepiece_length', + full_name='sentencepiece.TrainerSpec.max_sentencepiece_length', + index=20, + number=20, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=16, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='split_by_unicode_script', + full_name='sentencepiece.TrainerSpec.split_by_unicode_script', + index=21, + number=21, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='split_by_number', + full_name='sentencepiece.TrainerSpec.split_by_number', + index=22, + number=23, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='split_by_whitespace', + full_name='sentencepiece.TrainerSpec.split_by_whitespace', + index=23, + number=22, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='treat_whitespace_as_suffix', + full_name='sentencepiece.TrainerSpec.treat_whitespace_as_suffix', + index=24, + number=24, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='allow_whitespace_only_pieces', + full_name='sentencepiece.TrainerSpec.allow_whitespace_only_pieces', + index=25, + 
number=26, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='split_digits', + full_name='sentencepiece.TrainerSpec.split_digits', + index=26, + number=25, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='pretokenization_delimiter', + full_name='sentencepiece.TrainerSpec.pretokenization_delimiter', + index=27, + number=53, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='control_symbols', + full_name='sentencepiece.TrainerSpec.control_symbols', + index=28, + number=30, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='user_defined_symbols', + full_name='sentencepiece.TrainerSpec.user_defined_symbols', + index=29, + number=31, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='required_chars', + full_name='sentencepiece.TrainerSpec.required_chars', + index=30, + number=36, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='byte_fallback', + full_name='sentencepiece.TrainerSpec.byte_fallback', + index=31, + number=35, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='vocabulary_output_piece_score', + full_name='sentencepiece.TrainerSpec.vocabulary_output_piece_score', + index=32, + number=32, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='hard_vocab_limit', + full_name='sentencepiece.TrainerSpec.hard_vocab_limit', + index=33, + number=33, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='use_all_vocab', + full_name='sentencepiece.TrainerSpec.use_all_vocab', + index=34, + number=34, + type=8, + cpp_type=7, 
+ label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='unk_id', + full_name='sentencepiece.TrainerSpec.unk_id', + index=35, + number=40, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='bos_id', + full_name='sentencepiece.TrainerSpec.bos_id', + index=36, + number=41, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='eos_id', + full_name='sentencepiece.TrainerSpec.eos_id', + index=37, + number=42, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=2, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='pad_id', + full_name='sentencepiece.TrainerSpec.pad_id', + index=38, + number=43, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=-1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='unk_piece', + full_name='sentencepiece.TrainerSpec.unk_piece', + index=39, + number=45, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='bos_piece', + full_name='sentencepiece.TrainerSpec.bos_piece', + index=40, + number=46, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='eos_piece', + full_name='sentencepiece.TrainerSpec.eos_piece', + index=41, + number=47, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='pad_piece', + full_name='sentencepiece.TrainerSpec.pad_piece', + index=42, + number=48, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='unk_surface', + full_name='sentencepiece.TrainerSpec.unk_surface', + index=43, + number=44, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b(" \342\201\207 ").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + 
extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='train_extremely_large_corpus', + full_name='sentencepiece.TrainerSpec.train_extremely_large_corpus', + index=44, + number=49, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='seed_sentencepieces_file', + full_name='sentencepiece.TrainerSpec.seed_sentencepieces_file', + index=45, + number=54, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[ + _TRAINERSPEC_MODELTYPE, + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=45, + serialized_end=1617, +) + + +_NORMALIZERSPEC = _descriptor.Descriptor( + name='NormalizerSpec', + full_name='sentencepiece.NormalizerSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', + full_name='sentencepiece.NormalizerSpec.name', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='precompiled_charsmap', + full_name='sentencepiece.NormalizerSpec.precompiled_charsmap', + index=1, + number=2, + type=12, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b(""), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='add_dummy_prefix', + full_name='sentencepiece.NormalizerSpec.add_dummy_prefix', + index=2, + number=3, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='remove_extra_whitespaces', + full_name='sentencepiece.NormalizerSpec.remove_extra_whitespaces', + index=3, + number=4, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='escape_whitespaces', + full_name='sentencepiece.NormalizerSpec.escape_whitespaces', + index=4, + number=5, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='normalization_rule_tsv', + full_name='sentencepiece.NormalizerSpec.normalization_rule_tsv', + index=5, + number=6, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + 
default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=1620, + serialized_end=1829, +) + + +_SELFTESTDATA_SAMPLE = _descriptor.Descriptor( + name='Sample', + full_name='sentencepiece.SelfTestData.Sample', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='input', + full_name='sentencepiece.SelfTestData.Sample.input', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='expected', + full_name='sentencepiece.SelfTestData.Sample.expected', + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=1900, + serialized_end=1941, +) + +_SELFTESTDATA = _descriptor.Descriptor( + name='SelfTestData', + full_name='sentencepiece.SelfTestData', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='samples', + full_name='sentencepiece.SelfTestData.samples', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[ + _SELFTESTDATA_SAMPLE, + ], + enum_types=[], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=1831, + serialized_end=1952, +) + + +_MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor( + name='SentencePiece', + full_name='sentencepiece.ModelProto.SentencePiece', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='piece', + full_name='sentencepiece.ModelProto.SentencePiece.piece', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='score', + full_name='sentencepiece.ModelProto.SentencePiece.score', + index=1, + number=2, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='type', + full_name='sentencepiece.ModelProto.SentencePiece.type', + index=2, + number=3, + type=14, + cpp_type=8, 
+ label=1, + has_default_value=True, + default_value=1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[ + _MODELPROTO_SENTENCEPIECE_TYPE, + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=2244, + serialized_end=2454, +) + +_MODELPROTO = _descriptor.Descriptor( + name='ModelProto', + full_name='sentencepiece.ModelProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='pieces', + full_name='sentencepiece.ModelProto.pieces', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='trainer_spec', + full_name='sentencepiece.ModelProto.trainer_spec', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='normalizer_spec', + full_name='sentencepiece.ModelProto.normalizer_spec', + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='self_test_data', + full_name='sentencepiece.ModelProto.self_test_data', + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='denormalizer_spec', + full_name='sentencepiece.ModelProto.denormalizer_spec', + index=4, + number=5, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[ + _MODELPROTO_SENTENCEPIECE, + ], + enum_types=[], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=1955, + serialized_end=2465, +) + +_TRAINERSPEC.fields_by_name['model_type'].enum_type = _TRAINERSPEC_MODELTYPE +_TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC +_SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA +_SELFTESTDATA.fields_by_name['samples'].message_type = _SELFTESTDATA_SAMPLE +_MODELPROTO_SENTENCEPIECE.fields_by_name['type'].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE +_MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO +_MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE +_MODELPROTO.fields_by_name['pieces'].message_type = _MODELPROTO_SENTENCEPIECE +_MODELPROTO.fields_by_name['trainer_spec'].message_type = _TRAINERSPEC +_MODELPROTO.fields_by_name['normalizer_spec'].message_type = _NORMALIZERSPEC 
+_MODELPROTO.fields_by_name['self_test_data'].message_type = _SELFTESTDATA +_MODELPROTO.fields_by_name['denormalizer_spec'].message_type = _NORMALIZERSPEC +DESCRIPTOR.message_types_by_name['TrainerSpec'] = _TRAINERSPEC +DESCRIPTOR.message_types_by_name['NormalizerSpec'] = _NORMALIZERSPEC +DESCRIPTOR.message_types_by_name['SelfTestData'] = _SELFTESTDATA +DESCRIPTOR.message_types_by_name['ModelProto'] = _MODELPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +TrainerSpec = _reflection.GeneratedProtocolMessageType( + 'TrainerSpec', + (_message.Message,), + dict( + DESCRIPTOR=_TRAINERSPEC, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec) + ), +) +_sym_db.RegisterMessage(TrainerSpec) + +NormalizerSpec = _reflection.GeneratedProtocolMessageType( + 'NormalizerSpec', + (_message.Message,), + dict( + DESCRIPTOR=_NORMALIZERSPEC, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec) + ), +) +_sym_db.RegisterMessage(NormalizerSpec) + +SelfTestData = _reflection.GeneratedProtocolMessageType( + 'SelfTestData', + (_message.Message,), + dict( + Sample=_reflection.GeneratedProtocolMessageType( + 'Sample', + (_message.Message,), + dict( + DESCRIPTOR=_SELFTESTDATA_SAMPLE, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample) + ), + ), + DESCRIPTOR=_SELFTESTDATA, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData) + ), +) +_sym_db.RegisterMessage(SelfTestData) +_sym_db.RegisterMessage(SelfTestData.Sample) + +ModelProto = _reflection.GeneratedProtocolMessageType( + 'ModelProto', + (_message.Message,), + dict( + SentencePiece=_reflection.GeneratedProtocolMessageType( + 'SentencePiece', + (_message.Message,), + dict( + DESCRIPTOR=_MODELPROTO_SENTENCEPIECE, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece) + ), + ), + DESCRIPTOR=_MODELPROTO, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto) + ), +) +_sym_db.RegisterMessage(ModelProto) +_sym_db.RegisterMessage(ModelProto.SentencePiece) + + +DESCRIPTOR._options = None +_TRAINERSPEC.fields_by_name['mining_sentence_size']._options = None +_TRAINERSPEC.fields_by_name['training_sentence_size']._options = None +# @@protoc_insertion_point(module_scope) diff --git a/scripts/checkpoint_averaging/checkpoint_averaging.py b/scripts/checkpoint_averaging/checkpoint_averaging.py new file mode 100755 index 000000000000..863fc820acb9 --- /dev/null +++ b/scripts/checkpoint_averaging/checkpoint_averaging.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Builds a .nemo file with average weights over multiple .ckpt files (assumes .ckpt files in same folder as .nemo file). 
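+
+In outline (a sketch of the logic implemented below, not additional behavior): the averaging is an
+element-wise mean over matching keys of the checkpoint state dicts, i.e.
+    avg_state[k] = (ckpt_1[k] + ... + ckpt_N[k]) / N
+while integer-typed buffers (e.g. BatchNorm's num_batches_tracked) are only accumulated, not divided.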
+
+Usage example for building *-averaged.nemo for a given .nemo file:
+
+NeMo/scripts/checkpoint_averaging/checkpoint_averaging.py my_model.nemo
+
+Usage example for building *-averaged.nemo files for all results in sub-directories under the current path:
+
+find . -name '*.nemo' | grep -v -- "-averaged.nemo" | xargs NeMo/scripts/checkpoint_averaging/checkpoint_averaging.py
+
+
+NOTE: if you get the following error `AttributeError: Can't get attribute '???' on `,
+    use --import_fname_list with all files that contain the missing classes.
+"""
+
+import argparse
+import glob
+import importlib
+import os
+import sys
+
+import torch
+from tqdm.auto import tqdm
+
+from nemo.core import ModelPT
+from nemo.utils import logging, model_utils
+
+
+def main():
+    """
+    Main function
+    """
+
+    logging.info("This script is deprecated and will be removed in the 25.01 release.")
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        'model_fname_list',
+        metavar='NEMO_FILE_OR_FOLDER',
+        type=str,
+        nargs='+',
+        help='Input .nemo files (or folders that contain them) to parse',
+    )
+    parser.add_argument(
+        '--import_fname_list',
+        metavar='FILE',
+        type=str,
+        nargs='+',
+        default=[],
+        help='A list of Python file names to "from FILE import *"',
+    )
+    parser.add_argument(
+        '--class_path',
+        type=str,
+        default='',
+        help='A path to class "module.submodule.class" (if given)',
+    )
+    args = parser.parse_args()
+
+    logging.info(
+        f"\n\nIMPORTANT:\nIf you get the following error:\n\t"
+        "(AttributeError: Can't get attribute '???' on )\nuse:\n\t"
+        "--import_fname_list\nfor all files that contain missing classes.\n\n"
+    )
+
+    for fn in args.import_fname_list:
+        logging.info(f"Importing * from {fn}")
+        sys.path.insert(0, os.path.dirname(fn))
+        globals().update(importlib.import_module(os.path.splitext(os.path.basename(fn))[0]).__dict__)
+
+    device = torch.device("cpu")
+
+    # loop over all .nemo files (or folders that contain them)
+    for model_fname_i, model_fname in enumerate(args.model_fname_list):
+        if not model_fname.endswith(".nemo"):
+            # assume model_fname is a folder that contains a single .nemo file
+            nemo_files = list(
+                filter(lambda fn: not fn.endswith("-averaged.nemo"), glob.glob(os.path.join(model_fname, "*.nemo")))
+            )
+            if len(nemo_files) != 1:
+                raise RuntimeError(f"Expected exactly one .nemo file but discovered {len(nemo_files)} .nemo files")
+
+            model_fname = nemo_files[0]
+
+        model_folder_path = os.path.dirname(model_fname)
+        fn, fe = os.path.splitext(model_fname)
+        avg_model_fname = f"{fn}-averaged{fe}"
+
+        logging.info(f"\n===> [{model_fname_i+1} / {len(args.model_fname_list)}] Parsing folder {model_folder_path}\n")
+
+        # restore model from .nemo file path
+        model_cfg = ModelPT.restore_from(restore_path=model_fname, return_config=True)
+        if args.class_path:
+            classpath = args.class_path
+        else:
+            classpath = model_cfg.target  # original class path
+        imported_class = model_utils.import_class_by_path(classpath)
+        logging.info(f"Loading model {model_fname}")
+        nemo_model = imported_class.restore_from(restore_path=model_fname, map_location=device)
+
+        # search for all checkpoints (ignore -last.ckpt)
+        checkpoint_paths = [
+            os.path.join(model_folder_path, x)
+            for x in os.listdir(model_folder_path)
+            if x.endswith('.ckpt') and not x.endswith('-last.ckpt')
+        ]
+        """ < Checkpoint Averaging Logic > """
+        # load state dicts
+        n = len(checkpoint_paths)
+        avg_state = None
+
+        logging.info(f"Averaging {n} checkpoints ...")
+
+        for ix, path in enumerate(tqdm(checkpoint_paths, total=n, desc='Averaging 
checkpoints')): + checkpoint = torch.load(path, map_location=device, weights_only=False) + + if 'state_dict' in checkpoint: + checkpoint = checkpoint['state_dict'] + else: + raise RuntimeError(f"Checkpoint from {path} does not include a state_dict.") + + if ix == 0: + # Initial state + avg_state = checkpoint + + logging.info(f"Initialized average state dict with checkpoint:\n\t{path}") + else: + # Accumulated state + for k in avg_state: + avg_state[k] = avg_state[k] + checkpoint[k] + + logging.info(f"Updated average state dict with state from checkpoint:\n\t{path}") + + for k in avg_state: + if str(avg_state[k].dtype).startswith("torch.int"): + # For int type, not averaged, but only accumulated. + # e.g. BatchNorm.num_batches_tracked + pass + else: + avg_state[k] = avg_state[k] / n + + # restore merged weights into model + nemo_model.load_state_dict(avg_state, strict=True) + # Save model + logging.info(f"Saving average model to:\n\t{avg_model_fname}") + nemo_model.save_to(avg_model_fname) + + +if __name__ == '__main__': + main() diff --git a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py index e8562e686671..50c0de65985b 100644 --- a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +++ b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py @@ -85,7 +85,7 @@ from dataclasses import dataclass, field from datetime import datetime from io import BytesIO -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import numpy as np import soundfile @@ -563,15 +563,42 @@ def create_concatenated_dataset( metadata_yaml = OmegaConf.structured(metadata) OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) - def _read_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): + def _read_manifest(self, manifest_path: Union[str, List[str]], config: ASRTarredDatasetConfig): """Read and filters data from the manifest""" + entries = [] + total_duration = 0.0 + filtered_entries = [] + filtered_duration = 0.0 + + if isinstance(manifest_path, str): + manifest_paths = manifest_path.split(",") + else: + manifest_paths = manifest_path + + print(f"Found {len(manifest_paths)} manifest files to be processed") + for manifest_file in manifest_paths: + entries_i, total_dur_i, filtered_ent_i, filtered_dur_i = self._read_single_manifest( + str(manifest_file), config + ) + entries.extend(entries_i) + total_duration += total_dur_i + filtered_entries.extend(filtered_ent_i) + filtered_duration += filtered_dur_i + + return entries, total_duration, filtered_entries, filtered_duration + + def _read_single_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): # Read the existing manifest entries = [] total_duration = 0.0 filtered_entries = [] filtered_duration = 0.0 + print(f"Reading manifest: {manifest_path}") with open(manifest_path, 'r', encoding='utf-8') as m: for line in m: + line = line.strip() + if not line: + continue entry = json.loads(line) audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file" if config.slice_with_offset and "offset" not in entry: diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index b96fbec6ac46..65e17fda7bdf 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -17,7 +17,6 @@ import math import sys from numbers import Number -from typing import Iterable, Literal import click import lightning.pytorch as pl @@ -375,7 
+374,10 @@ def oomptimizer( config_path is None and module_name is None ), "--pretrained-name cannot be used together with --module-name/--config-path" click.echo(f"Intializing ASR model from pretrained checkpoint {pretrained_name}.") - model = ASRModel.from_pretrained(pretrained_name, trainer=trainer).to(device) + if pretrained_name.endswith('.nemo'): + model = ASRModel.restore_from(pretrained_name, trainer=trainer).to(device) + else: + model = ASRModel.from_pretrained(pretrained_name, trainer=trainer).to(device) else: assert config_path is not None, "--module-name requires --config-path to be specified as well." assert module_name is not None, "--config-path requires --module-name to be specified as well." diff --git a/tools/nemo_forced_aligner/align_eou.py b/tools/nemo_forced_aligner/align_eou.py new file mode 100644 index 000000000000..54e7bce8ff6f --- /dev/null +++ b/tools/nemo_forced_aligner/align_eou.py @@ -0,0 +1,574 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import math +import os +import shutil +import unicodedata +import uuid +from dataclasses import dataclass, field, is_dataclass +from pathlib import Path +from string import punctuation +from typing import List, Optional + +import torch +from omegaconf import OmegaConf +from utils.data_prep import ( + get_batch_starts_ends, + get_manifest_lines_batch, + is_entry_in_all_lines, + is_entry_in_any_lines, +) +from utils.make_ass_files import make_ass_files +from utils.make_ctm_files import make_ctm_files +from utils.make_output_manifest import write_manifest_out_line + +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR +from nemo.collections.asr.parts.utils.transcribe_utils import setup_model +from nemo.core.config import hydra_runner +from nemo.utils import logging + +try: + from nemo.collections.asr.parts.utils.aligner_utils import ( + add_t_start_end_to_utt_obj, + get_batch_variables, + viterbi_decoding, + ) +except ImportError: + raise ImportError( + "Missing required dependency for NFA. " + "Install NeMo with NFA utilities support:\n" + " pip install 'nemo_toolkit[all]>=2.5.0'\n" + "Or install the latest development version:\n" + " pip install git+https://github.com/NVIDIA/NeMo.git" + ) + +""" +Align the utterances in manifest_filepath. +Results are saved in ctm files in output_dir as well as json manifest in output_manifest_filepath. +If no output_manifest_filepath is specified, it will save the results in the same parent directory as +the input manifest_filepath. + +Arguments: + pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded + from NGC and used for generating the log-probs which we will use to do alignment. + Note: NFA can only use CTC models (not Transducer models) at the moment. 
+    model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the
+        log-probs which we will use to do alignment.
+        Note: NFA can only use CTC models (not Transducer models) at the moment.
+        Note: if a model_path is provided, it will override the pretrained_name.
+    manifest_filepath: filepath to the manifest of the data you want to align,
+        containing 'audio_filepath' and 'text' fields.
+    output_dir: the folder where output CTM files and new JSON manifest will be saved.
+    output_manifest_filepath: Optional[str] = None  # output manifest with sou_time and eou_time
+    manifest_pattern: Optional[str] = None  # pattern used in Path.glob() for finding manifests
+
+    align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription
+        as the reference text for the forced alignment.
+    transcribe_device: None, or a string specifying the device that will be used for generating log-probs (i.e. "transcribing").
+        The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
+        (otherwise will set it to 'cpu').
+    viterbi_device: None, or a string specifying the device that will be used for doing Viterbi decoding.
+        The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
+        (otherwise will set it to 'cpu').
+    batch_size: int specifying the batch size that will be used for generating log-probs and doing Viterbi decoding.
+    use_local_attention: boolean flag specifying whether to try to use local attention for the ASR model (will only
+        work if the ASR model is a Conformer model). If local attention is used, we will set the local attention context
+        size to [64, 64].
+    additional_segment_grouping_separator: an optional string used to separate the text into smaller segments.
+        If this is not specified, then the whole text will be treated as a single segment.
+    remove_blank_tokens_from_ctm: a boolean denoting whether to remove blank tokens from token-level output CTMs.
+    audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath
+        we will use (starting from the final part of the audio_filepath) to determine the
+        utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath
+        will be replaced with dashes, so as not to change the number of space-separated elements in the
+        CTM files.
+        e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e1"
+        e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e1"
+        e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e1"
+    use_buffered_chunked_streaming: False; if set to True, use buffered chunked streaming to get the logits for alignment.
+        This flag is useful when aligning large audio files.
+        However, chunked streaming inference currently does not support batch inference,
+        which means that even if you set batch_size > 1, files will be inferred one by one instead of
+        as a whole batch.
+    chunk_len_in_secs: float, chunk length in seconds
+    total_buffer_in_secs: float, length of the buffer (chunk + left and right padding) in seconds
+    chunk_batch_size: int, batch size for buffered chunk inference,
+        which will cut each audio file into segments and run inference on chunk_batch_size segments at a time
+
+    simulate_cache_aware_streaming: False; if set to True, use cache-aware streaming to get the logits for alignment
+
+    save_output_file_formats: List of strings specifying what type of output files to save (default: ["ctm", "ass"])
+    ctm_file_config: CTMFileConfig to specify the configuration of the output CTM files
+    ass_file_config: ASSFileConfig to specify the configuration of the output ASS files
+"""
+
+
+@dataclass
+class CTMFileConfig:
+    remove_blank_tokens: bool = False
+    # minimum duration (in seconds) for timestamps in the CTM. If any line in the CTM has a
+    # duration lower than this, it will be enlarged from the middle outwards until it
+    # meets the minimum_timestamp_duration, or reaches the beginning or end of the audio file.
+    # Note that this may cause timestamps to overlap.
+    minimum_timestamp_duration: float = 0
+
+
+@dataclass
+class ASSFileConfig:
+    fontsize: int = 20
+    vertical_alignment: str = "center"
+    # if resegment_text_to_fill_space is True, the ASS files will use new segments
+    # such that each segment will not take up more than (approximately) max_lines_per_segment
+    # when the ASS file is applied to a video
+    resegment_text_to_fill_space: bool = False
+    max_lines_per_segment: int = 2
+    text_already_spoken_rgb: List[int] = field(default_factory=lambda: [49, 46, 61])  # dark gray
+    text_being_spoken_rgb: List[int] = field(default_factory=lambda: [57, 171, 9])  # dark green
+    text_not_yet_spoken_rgb: List[int] = field(default_factory=lambda: [194, 193, 199])  # light gray
+
+
+@dataclass
+class AlignmentConfig:
+    # Required configs
+    pretrained_name: Optional[str] = None
+    model_path: Optional[str] = None
+    manifest_filepath: Optional[str] = None  # path to manifest file or directory
+    output_dir: Optional[str] = '.tmp'  # temporary output dir; removed after alignment if remove_tmp_dir is True
+    output_manifest_filepath: Optional[str] = None  # output manifest with sou_time and eou_time
+    manifest_pattern: Optional[str] = None  # pattern used in Path.glob() for finding manifests
+
+    # General configs
+    align_using_pred_text: bool = False
+    transcribe_device: Optional[str] = None
+    viterbi_device: Optional[str] = None
+    batch_size: int = 1
+    use_local_attention: bool = True
+    additional_segment_grouping_separator: Optional[str] = None
+    audio_filepath_parts_in_utt_id: int = 4
+
+    # Buffered chunked streaming configs
+    use_buffered_chunked_streaming: bool = False
+    chunk_len_in_secs: float = 1.6
+    total_buffer_in_secs: float = 4.0
+    chunk_batch_size: int = 32
+
+    # Cache aware streaming configs
+    simulate_cache_aware_streaming: Optional[bool] = False
+
+    # Output file configs
+    save_output_file_formats: List[str] = field(default_factory=lambda: ["ctm", "ass"])
+    ctm_file_config: CTMFileConfig = field(default_factory=lambda: CTMFileConfig())
+    ass_file_config: ASSFileConfig = field(default_factory=lambda: ASSFileConfig())
+
+    # remove tmp dir after alignment
+    remove_tmp_dir: bool = False
+    clean_text: bool = True
+
+    # For multi-node multi-gpu processing
+    num_nodes: int = 1  # total num of nodes/machines
+    num_gpus: int = 1  # num of GPUs per node/machine
+    node_idx: int = 0  # current node index
+    gpu_idx: int = 0  # current GPU index
+
+
+def unicode_to_ascii(text: str) -> str:
+    """
+    
Converts text with accented or special Latin characters (e.g., ó, ñ, ū, ō) + into their closest ASCII equivalents. + """ + # Normalize the string to NFKD to separate base characters from diacritics + normalized = unicodedata.normalize('NFKD', text) + + # Encode to ASCII bytes, ignoring characters that can't be converted + ascii_bytes = normalized.encode('ascii', 'ignore') + + # Decode back to string + ascii_text = ascii_bytes.decode('ascii') + + return ascii_text + + +def drop_pnc(text): + """ + Clean the text by removing invalid characters and converting to lowercase. + + :param text: Input text. + :return: Cleaned text. + """ + valid_chars = "abcdefghijklmnopqrstuvwxyz'" + text = text.lower() + text = unicode_to_ascii(text) + text = text.replace(":", " ") + text = ''.join([c for c in text if c in valid_chars or c.isspace() or c == "'"]) + return " ".join(text.split()).strip() + + +def clean_text(manifest: List[dict]): + punctuations = punctuation.replace("'", "") + # replace_with_space = [char for char in '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→'] + replace_with_blank = [char for char in '`¨´‘’“”`ʻ‘’“"‘”'] + replace_with_apos = [char for char in '‘’ʻ‘’‘'] + + for i in range(len(manifest)): + text = manifest[i]["text"].strip().lower() # type: str + text = text.translate(str.maketrans("", "", punctuations)) + text = drop_pnc(text) + for c in replace_with_blank: + text = text.replace(c, "") + for c in replace_with_apos: + text = text.replace(c, "'") + manifest[i]["text"] = text + return manifest + + +def get_manifests_for_this_rank(manifest_list, num_nodes, num_gpus, node_idx, gpu_idx): + """ + Get the manifest files for this rank. + """ + if len(manifest_list) == 0: + return manifest_list + + assert num_nodes > 0, "num_nodes must be greater than 0" + assert num_gpus > 0, "num_gpus must be greater than 0" + assert 0 <= node_idx < num_nodes, f"node_idx {node_idx} must be between 0 and {num_nodes - 1}" + assert 0 <= gpu_idx < num_gpus, f"gpu_idx {gpu_idx} must be between 0 and {num_gpus - 1}" + + manifests_this_node = [] + for i, manifest_file in enumerate(manifest_list): + if num_nodes > 1: + if i % num_nodes == node_idx: + manifests_this_node.append(manifest_file) + else: + manifests_this_node.append(manifest_file) + + manifests_this_gpu = [] + for i, manifest_file in enumerate(manifests_this_node): + if num_gpus > 1: + if i % num_gpus == gpu_idx: + manifests_this_gpu.append(manifest_file) + else: + manifests_this_gpu.append(manifest_file) + return manifests_this_gpu + + +@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig) +def main(cfg: AlignmentConfig): + + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + # Validate config + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None") + + if cfg.model_path is not None and cfg.pretrained_name is not None: + raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None") + + if cfg.manifest_filepath is None: + raise ValueError("cfg.manifest_filepath must be specified") + + if cfg.output_dir is None and not cfg.remove_tmp_dir: + raise ValueError("cfg.output_dir must be specified if cfg.remove_tmp_dir is False") + + if cfg.batch_size < 1: + raise ValueError("cfg.batch_size cannot be zero or a negative number") + + if cfg.additional_segment_grouping_separator == "" or cfg.additional_segment_grouping_separator == " ": + raise ValueError("cfg.additional_grouping_separator 
cannot be empty string or space character")
+
+    if cfg.ctm_file_config.minimum_timestamp_duration < 0:
+        raise ValueError("cfg.minimum_timestamp_duration cannot be a negative number")
+
+    if cfg.ass_file_config.vertical_alignment not in ["top", "center", "bottom"]:
+        raise ValueError("cfg.ass_file_config.vertical_alignment must be one of 'top', 'center' or 'bottom'")
+
+    for rgb_list in [
+        cfg.ass_file_config.text_already_spoken_rgb,
+        cfg.ass_file_config.text_being_spoken_rgb,
+        cfg.ass_file_config.text_not_yet_spoken_rgb,
+    ]:
+        if len(rgb_list) != 3:
+            raise ValueError(
+                "cfg.ass_file_config.text_already_spoken_rgb,"
+                " cfg.ass_file_config.text_being_spoken_rgb,"
+                " and cfg.ass_file_config.text_not_yet_spoken_rgb all need to contain"
+                " exactly 3 elements."
+            )
+
+    # init devices
+    if cfg.transcribe_device is None:
+        transcribe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    else:
+        transcribe_device = torch.device(cfg.transcribe_device)
+    logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}")
+
+    if cfg.viterbi_device is None:
+        viterbi_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    else:
+        viterbi_device = torch.device(cfg.viterbi_device)
+    logging.info(f"Device to be used for viterbi step (`viterbi_device`) is {viterbi_device}")
+
+    if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda':
+        logging.warning(
+            'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors '
+            'it may help to change both devices to be the CPU.'
+        )
+
+    # load model
+    model, _ = setup_model(cfg, transcribe_device)
+    model.eval()
+
+    if isinstance(model, EncDecHybridRNNTCTCModel):
+        model.change_decoding_strategy(decoder_type="ctc")
+
+    if cfg.use_local_attention:
+        logging.info(
+            "Flag use_local_attention is set to True => will try to use local attention for model if it allows it"
+        )
+        model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[64, 64])
+
+    if not (isinstance(model, EncDecCTCModel) or isinstance(model, EncDecHybridRNNTCTCModel)):
+        raise NotImplementedError(
+            "Model is not an instance of NeMo EncDecCTCModel or EncDecHybridRNNTCTCModel."
+            " Currently only instances of these models are supported"
+        )
+
+    if cfg.ctm_file_config.minimum_timestamp_duration > 0:
+        logging.warning(
+            f"cfg.ctm_file_config.minimum_timestamp_duration has been set to {cfg.ctm_file_config.minimum_timestamp_duration} seconds. "
+            "This may cause the alignments for some tokens/words/additional segments to be overlapping."
+ ) + + buffered_chunk_params = {} + if cfg.use_buffered_chunked_streaming: + model_cfg = copy.deepcopy(model._cfg) + + OmegaConf.set_struct(model_cfg.preprocessor, False) + # some changes for streaming scenario + model_cfg.preprocessor.dither = 0.0 + model_cfg.preprocessor.pad_to = 0 + + if model_cfg.preprocessor.normalize != "per_feature": + logging.error( + "Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently" + ) + # Disable config overwriting + OmegaConf.set_struct(model_cfg.preprocessor, True) + + feature_stride = model_cfg.preprocessor['window_stride'] + model_stride_in_secs = feature_stride * cfg.model_downsample_factor + total_buffer = cfg.total_buffer_in_secs + chunk_len = float(cfg.chunk_len_in_secs) + tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs) + mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs) + logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}") + + model = FrameBatchASR( + asr_model=model, + frame_len=chunk_len, + total_buffer=cfg.total_buffer_in_secs, + batch_size=cfg.chunk_batch_size, + ) + buffered_chunk_params = { + "delay": mid_delay, + "model_stride_in_secs": model_stride_in_secs, + "tokens_per_chunk": tokens_per_chunk, + } + + if Path(cfg.manifest_filepath).is_file(): + manifest_list = [cfg.manifest_filepath] + elif Path(cfg.manifest_filepath).is_dir(): + if cfg.manifest_pattern is not None: + manifest_list = list(Path(cfg.manifest_filepath).glob(cfg.manifest_pattern)) + else: + manifest_list = list(Path(cfg.manifest_filepath).glob("*.json")) + else: + raise ValueError( + f"cfg.manifest_filepath is not a valid file or directory. " + f"Please check the path: {cfg.manifest_filepath}" + ) + + origin_output_manifest_filepath = cfg.output_manifest_filepath + + manifest_list = get_manifests_for_this_rank(manifest_list, cfg.num_nodes, cfg.num_gpus, cfg.node_idx, cfg.gpu_idx) + logging.info(f"Found {len(manifest_list)} manifest files to process.") + # process each manifest file + for manifest_filepath in manifest_list: + logging.info(f"Processing manifest file: {manifest_filepath}") + cfg.manifest_filepath = str(manifest_filepath) + + if origin_output_manifest_filepath is None: + manifest_stem = Path(manifest_filepath).stem + cfg.output_manifest_filepath = str(Path(manifest_filepath).parent / f"{manifest_stem}-aligned.json") + elif len(manifest_list) > 1 and origin_output_manifest_filepath is not None: + raise ValueError( + "cfg.output_manifest_filepath must be None when processing multiple manifest files. " + "Please set it to None." + ) + + if not cfg.remove_tmp_dir and len(manifest_list) > 1: + # if keep alignment files, then we need to set output_dir to be different for each manifest + cfg.output_dir = str(Path(manifest_filepath).parent / f"{Path(manifest_filepath).stem}_alignment") + + process_single_manifest(cfg, model, buffered_chunk_params, viterbi_device) + logging.info(f"Output manifest saved to: {cfg.output_manifest_filepath}") + + logging.info("All manifest files processed successfully.") + + +def process_single_manifest(cfg: AlignmentConfig, model, buffered_chunk_params, viterbi_device): + # Validate manifest contents + if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"): + raise RuntimeError( + "At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. " + "All lines must contain an 'audio_filepath' entry." 
+ ) + + if cfg.align_using_pred_text: + if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"): + raise RuntimeError( + "Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath " + "contains 'pred_text' entries. This is because the audio will be transcribed and may produce " + "a different 'pred_text'. This may cause confusion." + ) + else: + if not is_entry_in_all_lines(cfg.manifest_filepath, "text"): + raise RuntimeError( + "At least one line in cfg.manifest_filepath does not contain a 'text' entry. " + "NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=False." + ) + + # get start and end line IDs of batches + starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size) + + # init output_timestep_duration = None and we will calculate and update it during the first batch + output_timestep_duration = None + + if cfg.remove_tmp_dir and cfg.output_dir is None: + cfg.output_dir = f"alignment-{uuid.uuid4()}" + + # init f_manifest_out + os.makedirs(cfg.output_dir, exist_ok=True) + tgt_manifest_name = str(Path(cfg.manifest_filepath).stem) + "_with_output_file_paths.json" + tgt_manifest_filepath = str(Path(cfg.output_dir) / tgt_manifest_name) + f_manifest_out = open(tgt_manifest_filepath, 'w') + + # get alignment and save in CTM batch-by-batch + for start, end in zip(starts, ends): + manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end) + + if cfg.clean_text: + manifest_lines_batch = clean_text(manifest_lines_batch) + ( + log_probs_batch, + y_batch, + T_batch, + U_batch, + utt_obj_batch, + output_timestep_duration, + ) = get_batch_variables( + manifest_lines_batch, + model, + cfg.additional_segment_grouping_separator, + cfg.align_using_pred_text, + cfg.audio_filepath_parts_in_utt_id, + output_timestep_duration, + cfg.simulate_cache_aware_streaming, + cfg.use_buffered_chunked_streaming, + buffered_chunk_params, + ) + + alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device) + + for utt_obj, alignment_utt in zip(utt_obj_batch, alignments_batch): + + utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration) + + if "ctm" in cfg.save_output_file_formats: + utt_obj = make_ctm_files( + utt_obj, + cfg.output_dir, + cfg.ctm_file_config, + ) + + if "ass" in cfg.save_output_file_formats: + utt_obj = make_ass_files(utt_obj, cfg.output_dir, cfg.ass_file_config) + + write_manifest_out_line( + f_manifest_out, + utt_obj, + ) + + f_manifest_out.close() + + # adding eou processing here + input_manifest_lines = [] + with open(cfg.manifest_filepath, 'r') as f: + for line in f.readlines(): + if line.strip(): + input_manifest_lines.append(json.loads(line)) + + output_manifest_lines = [] + with open(tgt_manifest_filepath, 'r') as f: + for i, line in enumerate(f.readlines()): + item = json.loads(line) + assert os.path.basename(input_manifest_lines[i]['audio_filepath']) == os.path.basename( + item['audio_filepath'] + ) + + if 'segments_level_ctm_filepath' not in item: + print( + f"`segments_level_ctm_filepath` not found for {input_manifest_lines[i]['audio_filepath']}, skipping" + ) + continue + + # get sou/eou time + with open(item['segments_level_ctm_filepath']) as f: + lines = [line.split() for line in f] + start_time = min([float(line[2]) for line in lines]) + end_time = max([float(line[2]) + float(line[3]) for line in lines]) + input_manifest_lines[i]['sou_time'] = start_time + input_manifest_lines[i]['eou_time'] = end_time + 
output_manifest_lines.append(input_manifest_lines[i]) + + with open(cfg.output_manifest_filepath, 'w') as f: + for item in output_manifest_lines: + f.write(json.dumps(item) + '\n') + + if cfg.remove_tmp_dir: # safely removing tmp dir after alignment + for file_or_folder in [ + tgt_manifest_filepath, + os.path.join(cfg.output_dir, 'ctm'), + os.path.join(cfg.output_dir, 'ass'), + ]: + if os.path.exists(file_or_folder): + if os.path.isfile(file_or_folder): + os.remove(file_or_folder) + else: + shutil.rmtree(file_or_folder) + if os.path.exists(cfg.output_dir) and len(os.listdir(cfg.output_dir)) == 0: + shutil.rmtree(cfg.output_dir) + + return None + + +if __name__ == "__main__": + main() diff --git a/tools/nemo_forced_aligner/utils/data_prep.py b/tools/nemo_forced_aligner/utils/data_prep.py index 3386b5744108..05900899c74e 100644 --- a/tools/nemo_forced_aligner/utils/data_prep.py +++ b/tools/nemo_forced_aligner/utils/data_prep.py @@ -69,6 +69,7 @@ def get_manifest_lines_batch(manifest_filepath, start, end): for line_i, line in enumerate(f): if line_i >= start and line_i <= end: data = json.loads(line) + data["audio_filepath"] = get_full_path(data["audio_filepath"], manifest_filepath) if "text" in data: # remove any BOM, any duplicated spaces, convert any # newline chars to spaces