diff --git a/core/lib/metrics.py b/core/lib/metrics.py
index bef84a42..6061a154 100644
--- a/core/lib/metrics.py
+++ b/core/lib/metrics.py
@@ -20,6 +20,7 @@ class EvaluationMetric(enum.Enum):
   F1_SCORE = 'f1_score'
   CONFUSION_MATRIX = 'confusion_matrix'
   INSTRUCTION_POINTER = 'instruction_pointer'
+  INSTRUCTION_POINTER_ENTROPY = 'instruction_pointer_entropy'
 
 
 def all_metric_names() -> Tuple[str]:
@@ -161,6 +162,25 @@ def instruction_pointers_to_images(instruction_pointer, multidevice: bool):
   return jnp.array(instruction_pointer_image_list)
 
 
+def instruction_pointers_to_entropy(instruction_pointer, multidevice: bool):
+  """Converts the given batched instruction pointer to an entropy value.
+
+  The entropy value measures the sharpness of the instruction pointer, i.e. how
+  hard vs soft it is. Returns the Shannon entropy of the per-step node
+  distributions, averaged over the batch and timestep dimensions.
+  """
+  if multidevice:
+    # instruction_pointer: device, batch_size / device, timesteps, num_nodes
+    instruction_pointer = instruction_pointer[0]
+
+  # instruction_pointer: batch_size / device, timesteps, num_nodes
+  # Guard against log(0): zero-probability entries contribute 0 to the sum.
+  log_probs = jnp.where(
+      instruction_pointer > 0, jnp.log(instruction_pointer), 0.0)
+  entropies = -jnp.sum(instruction_pointer * log_probs, axis=-1)
+  return jnp.mean(entropies)
+
+
 def pad(array, leading_dim_size: int):
   """Pad the leading dimension of the given array."""
   leading_dim_difference = max(0, leading_dim_size - array.shape[0])
diff --git a/core/lib/trainer.py b/core/lib/trainer.py
index 401bfe65..2a200a53 100644
--- a/core/lib/trainer.py
+++ b/core/lib/trainer.py
@@ -378,6 +378,14 @@ def run_train(self, dataset_path=DEFAULT_DATASET_PATH, split='train', steps=None
           transform_fn=functools.partial(
               metrics.instruction_pointers_to_images,
               multidevice=config.multidevice))
+      metrics.write_metric(
+          EvaluationMetric.INSTRUCTION_POINTER_ENTROPY.value,
+          aux,
+          train_writer.scalar,
+          step,
+          transform_fn=functools.partial(
+              metrics.instruction_pointers_to_entropy,
+              multidevice=config.multidevice))
 
     # Write validation metrics.
     valid_writer.scalar('loss', valid_loss, step)