@@ -20,6 +20,7 @@ class EvaluationMetric(enum.Enum):
2020 F1_SCORE = 'f1_score'
2121 CONFUSION_MATRIX = 'confusion_matrix'
2222 INSTRUCTION_POINTER = 'instruction_pointer'
23+ INSTRUCTION_POINTER_ENTROPY = 'instruction_pointer_entropy'
2324
2425
2526def all_metric_names () -> Tuple [str ]:
@@ -161,6 +162,33 @@ def instruction_pointers_to_images(instruction_pointer, multidevice: bool):
161162 return jnp .array (instruction_pointer_image_list )
162163
163164
165+ def instruction_pointers_to_entropy (instruction_pointer , multidevice : bool ):
166+ """Converts the given batched instruction pointer to an entropy value.
167+
168+ The entropy value measures the sharpness of the instruction pointer, i.e. how
169+ hard vs soft it is.
170+ """
171+ if multidevice :
172+ # instruction_pointer: device, batch_size / device, timesteps, num_nodes
173+ instruction_pointer = instruction_pointer [0 ]
174+
175+ # instruction_pointer: batch_size / device, timesteps, num_nodes
176+ instruction_pointer = jnp .transpose (instruction_pointer [:, :16 , :],
177+ (1 , 2 , 0 ))
178+ # instruction_pointer: logging_slice_size, num_nodes, timesteps
179+ instruction_pointer_image_list = [
180+ instruction_pointer_to_image (ip )
181+ for ip in instruction_pointer
182+ ]
183+ instruction_pointer_image_leading_dim_max = max (
184+ image .shape [0 ] for image in instruction_pointer_image_list )
185+ instruction_pointer_image_list = [
186+ pad (image , instruction_pointer_image_leading_dim_max )
187+ for image in instruction_pointer_image_list
188+ ]
189+ return jnp .array (instruction_pointer_image_list )
190+
191+
164192def pad (array , leading_dim_size : int ):
165193 """Pad the leading dimension of the given array."""
166194 leading_dim_difference = max (0 , leading_dim_size - array .shape [0 ])
0 commit comments