Skip to content
This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit cdb5d43

Browse files
committed
Add instruction pointer entropy metric.
Work in-progress.
1 parent 71deb99 commit cdb5d43

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

core/lib/metrics.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class EvaluationMetric(enum.Enum):
2020
F1_SCORE = 'f1_score'
2121
CONFUSION_MATRIX = 'confusion_matrix'
2222
INSTRUCTION_POINTER = 'instruction_pointer'
23+
INSTRUCTION_POINTER_ENTROPY = 'instruction_pointer_entropy'
2324

2425

2526
def all_metric_names() -> Tuple[str]:
@@ -161,6 +162,33 @@ def instruction_pointers_to_images(instruction_pointer, multidevice: bool):
161162
return jnp.array(instruction_pointer_image_list)
162163

163164

165+
def instruction_pointers_to_entropy(instruction_pointer, multidevice: bool):
166+
"""Converts the given batched instruction pointer to an entropy value.
167+
168+
The entropy value measures the sharpness of the instruction pointer, i.e. how
169+
hard vs soft it is.
170+
"""
171+
if multidevice:
172+
# instruction_pointer: device, batch_size / device, timesteps, num_nodes
173+
instruction_pointer = instruction_pointer[0]
174+
175+
# instruction_pointer: batch_size / device, timesteps, num_nodes
176+
instruction_pointer = jnp.transpose(instruction_pointer[:, :16, :],
177+
(1, 2, 0))
178+
# instruction_pointer: logging_slice_size, num_nodes, timesteps
179+
instruction_pointer_image_list = [
180+
instruction_pointer_to_image(ip)
181+
for ip in instruction_pointer
182+
]
183+
instruction_pointer_image_leading_dim_max = max(
184+
image.shape[0] for image in instruction_pointer_image_list)
185+
instruction_pointer_image_list = [
186+
pad(image, instruction_pointer_image_leading_dim_max)
187+
for image in instruction_pointer_image_list
188+
]
189+
return jnp.array(instruction_pointer_image_list)
190+
191+
164192
def pad(array, leading_dim_size: int):
165193
"""Pad the leading dimension of the given array."""
166194
leading_dim_difference = max(0, leading_dim_size - array.shape[0])

core/lib/trainer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ def run_train(self, dataset_path=DEFAULT_DATASET_PATH, split='train', steps=None
378378
transform_fn=functools.partial(
379379
metrics.instruction_pointers_to_images,
380380
multidevice=config.multidevice))
381+
metrics.write_metric(
382+
EvaluationMetric.INSTRUCTION_POINTER_ENTROPY.value,
383+
aux,
384+
train_writer.scalar,
385+
step,
386+
transform_fn=functools.partial(
387+
metrics.instruction_pointers_to_entropy,
388+
multidevice=config.multidevice))
381389

382390
# Write validation metrics.
383391
valid_writer.scalar('loss', valid_loss, step)

0 commit comments

Comments
 (0)