diff --git a/pytext/metric_reporters/word_tagging_metric_reporter.py b/pytext/metric_reporters/word_tagging_metric_reporter.py
index 925f10e0c..6ab843dce 100644
--- a/pytext/metric_reporters/word_tagging_metric_reporter.py
+++ b/pytext/metric_reporters/word_tagging_metric_reporter.py
@@ -5,6 +5,7 @@
 from collections import Counter
 from typing import Dict, List, NamedTuple
 
+import torch
 from pytext.common.constants import DatasetFieldName, Stage
 from pytext.data import CommonMetadata
 from pytext.metrics import (
@@ -103,6 +104,43 @@ def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
         self.label_vocabs = label_vocabs
         super().__init__(channels)
 
+    def add_batch_stats(
+        self, n_batches, preds, targets, scores, loss, m_input, **context
+    ):
+        """
+        Aggregates a batch of output data (predictions, scores, targets/true labels
+        and loss).
+
+        Args:
+            n_batches (int): number of current batch
+            preds (torch.Tensor): predictions of current batch
+            targets (torch.Tensor): targets of current batch
+            scores (torch.Tensor): scores of current batch
+            loss (double): average loss of current batch
+            m_input (Tuple[torch.Tensor, ...]): model inputs of current batch
+            context (Dict[str, Any]): any additional context data, it could be
+                either a list of data which maps to each example, or a single value
+                for the batch
+        """
+        self.n_batches = n_batches
+        self.aggregate_preds(preds, context)
+        self.aggregate_targets(targets, context)
+        self.aggregate_scores(scores)
+        for key, val in context.items():
+            if not (isinstance(val, torch.Tensor) or isinstance(val, List)):
+                continue
+            if key not in self.all_context:
+                self.all_context[key] = []
+            self.aggregate_data(self.all_context[key], val)
+        if loss is not None:
+            self.all_loss.append(float(loss))
+        self.batch_size.append(len(m_input[-1]))
+
+        # realtime stats
+        if DatasetFieldName.NUM_TOKENS in context:
+            self.realtime_meters["tps"].update(context[DatasetFieldName.NUM_TOKENS])
+            self.realtime_meters["ups"].update(1)
+
     @classmethod
     def from_config(cls, config, tensorizers):
         return MultiLabelSequenceTaggingMetricReporter(
@@ -129,6 +167,58 @@ def aggregate_targets(self, batch_targets, batch_context=None):
     def aggregate_scores(self, batch_scores):
         self.aggregate_tuple_data(self.all_scores, batch_scores)
 
+    def report_metric(
+        self,
+        model,
+        stage,
+        epoch,
+        reset=True,
+        print_to_channels=True,
+        optimizer=None,
+        privacy_engine=None,  # to be handled by the subclassed metric reporters
+    ):
+        """
+        Calculate metrics and average loss, report all statistic data to channels
+
+        Args:
+            model (nn.Module): the PyTorch neural network model.
+            stage (Stage): training, evaluation or test
+            epoch (int): current epoch
+            reset (bool): if all data should be reset after report, default is True
+            print_to_channels (bool): if report data to channels, default is True
+        """
+        self.gen_extra_context()
+        self.total_loss = self.calculate_loss()
+        metrics = self.calculate_metric()
+        model_select_metric = self.get_model_select_metric(metrics)
+
+        # print_to_channels is true only on gpu 0, but we need all gpus to sync
+        # metric
+        self.report_realtime_metric(stage)
+
+        if print_to_channels:
+            for channel in self.channels:
+                if stage in channel.stages:
+                    channel.report(
+                        stage,
+                        epoch,
+                        metrics,
+                        model_select_metric,
+                        self.total_loss,
+                        self.predictions_to_report(),
+                        self.targets_to_report(),
+                        self.all_scores,
+                        self.all_context,
+                        self.get_meta(),
+                        model,
+                        optimizer,
+                    )
+
+        if reset:
+            self._reset()
+            self._reset_realtime()
+        return metrics
+
     def calculate_metric(self):
         list_score_pred_expect = []
         for label_idx, _ in enumerate(self.label_names):
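Reviewer note: the snippet below is a minimal sketch, not part of this patch, showing the call pattern the two new methods expect. `run_eval_epoch`, `reporter`, and `batches` are hypothetical names; `reporter` is assumed to be a `MultiLabelSequenceTaggingMetricReporter` and `batches` an iterable of already-unpacked model outputs.

```python
from pytext.common.constants import Stage


def run_eval_epoch(reporter, model, batches, epoch):
    # Hypothetical driver loop: each batch is assumed to be unpacked into
    # exactly the arguments that add_batch_stats in this patch consumes.
    for n_batches, (preds, targets, scores, loss, m_input, context) in enumerate(
        batches
    ):
        # Accumulates preds/targets/scores, any tensor- or list-valued
        # context entries, the loss, and the batch size (len(m_input[-1])).
        reporter.add_batch_stats(
            n_batches, preds, targets, scores, loss, m_input, **context
        )
    # Computes metrics, reports to every channel registered for this stage,
    # then resets accumulated state (reset defaults to True).
    return reporter.report_metric(model, Stage.EVAL, epoch)
```

Per the comment in `report_metric`, `print_to_channels` is true only on GPU 0, so in a distributed run every rank still reaches `report_realtime_metric(stage)` before the channel loop, which keeps the metric sync collective.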
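Also worth flagging: the context filter in `add_batch_stats` checks `isinstance(val, List)` against `typing.List` (already imported at the top of this module), which CPython accepts for bare, unsubscripted generics. A standalone illustration with made-up sample values:

```python
from typing import List

import torch

# Mirrors the filter in add_batch_stats: only tensor- and list-valued
# context entries are kept; scalars and strings hit the `continue` branch.
for val in (torch.zeros(2), [1, 2], "utterance", 3.14):
    keep = isinstance(val, torch.Tensor) or isinstance(val, List)
    print(type(val).__name__, "kept" if keep else "skipped")
```

Newer code typically writes `isinstance(val, list)` instead, since the bare `typing.List` alias is discouraged in recent Python, but the two behave identically here.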