This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit 55bc7d7

Add a prediction visualization script.
Work in progress.
1 parent b8d7f31

File tree

4 files changed (+285 −2 lines changed)


core/data/data_io.py

Lines changed: 1 addition & 1 deletion

@@ -147,7 +147,7 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=
       'problem_id': [1],
       'submission_id': [1],
   })
-
+
   return shapes
core/lib/metrics.py

Lines changed: 2 additions & 1 deletion

@@ -128,7 +128,8 @@ def make_figure(*,
   fig = plt.figure()
   ax = fig.add_subplot(111)
   ax.set_title(title)
-  plt.imshow(data, interpolation=interpolation, **kwargs)
+  # plt.imshow(data, interpolation=interpolation, **kwargs)
+  plt.imshow(data, cmap='Greys', interpolation=interpolation, **kwargs)
   ax.set_aspect('equal')
   ax.set_xlabel(xlabel)
   ax.set_ylabel(ylabel)
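For context, a minimal sketch (not part of this commit) of what the new cmap='Greys' argument does: imshow renders the matrix in grayscale rather than matplotlib's default colormap, which suits probability heatmaps like the instruction-pointer plots below. The data here is an arbitrary placeholder.

import matplotlib.pyplot as plt
import numpy as np

data = np.random.rand(6, 12)  # placeholder matrix, e.g. an instruction-pointer heatmap
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_title('example')
plt.imshow(data, cmap='Greys', interpolation='nearest')  # grayscale, as in the updated make_figure
ax.set_aspect('equal')
fig.savefig('example.png')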
File renamed without changes.

scripts/visualize_predictions.py

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
"""Visualize model predictions."""

import dataclasses
import os

from absl import app
from absl import flags

from flax.training import checkpoints
from flax.training import common_utils
import imageio
import jax
import jax.numpy as jnp
from ml_collections.config_flags import config_flags
import tensorflow_datasets as tfds

from core.data import codenet
from core.data import codenet_paths
from core.data import error_kinds
from core.data import info as info_lib
from core.data import process
from core.lib import metrics
from core.lib import trainer

DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH
DEFAULT_CONFIG_PATH = codenet_paths.DEFAULT_CONFIG_PATH


flags.DEFINE_string('dataset_path', DEFAULT_DATASET_PATH, 'Dataset path.')
flags.DEFINE_string('latex_template_path', 'example_figure_template.tex',
                    'LaTeX template path.')
config_flags.DEFINE_config_file(
    name='config', default=DEFAULT_CONFIG_PATH, help_string='Config file.'
)
FLAGS = flags.FLAGS


def get_raise_contribution_at_step(instruction_pointer, raise_decisions, raise_index):
  # instruction_pointer.shape: num_nodes
  # raise_decisions.shape: num_nodes, 2
  # raise_index.shape: scalar.
  p_raise = raise_decisions[:, 0]
  raise_contribution = p_raise * instruction_pointer
  # raise_contribution.shape: num_nodes
  raise_contribution = raise_contribution.at[raise_index].set(0)
  return raise_contribution
get_raise_contribution_at_steps = jax.vmap(get_raise_contribution_at_step, in_axes=(0, 0, None))


def get_raise_contribution(instruction_pointer, raise_decisions, raise_index, step_limit):
  # instruction_pointer.shape: steps, num_nodes
  # raise_decisions.shape: steps, num_nodes, 2
  # raise_index.shape: scalar.
  # step_limit.shape: scalar.
  raise_contributions = get_raise_contribution_at_steps(
      instruction_pointer, raise_decisions, raise_index)
  # raise_contributions.shape: steps, num_nodes
  mask = jnp.arange(instruction_pointer.shape[0]) < step_limit
  # mask.shape: steps
  raise_contributions = jnp.where(mask[:, None], raise_contributions, 0)
  raise_contribution = jnp.sum(raise_contributions, axis=0)
  # raise_contribution.shape: num_nodes
  return raise_contribution
get_raise_contribution_batch = jax.vmap(get_raise_contribution)


def print_spans(raw):
  span_starts = raw.node_span_starts
  span_ends = raw.node_span_ends
  for i, (span_start, span_end) in enumerate(zip(span_starts, span_ends)):
    print(f'Span {i}: {raw.source[span_start:span_end]}')


def get_spans(raw):
  span_starts = raw.node_span_starts
  span_ends = raw.node_span_ends
  for i, (span_start, span_end) in enumerate(zip(span_starts, span_ends)):
    yield raw.source[span_start:span_end]


def set_config(config):
  """This function is hard-coded to load a particular checkpoint.

  It also sets the model part of the config to match the config of that checkpoint.
  Everything related to parameter construction must match.
  """

  # config.multidevice=False
  # config.batch_size=32
  # config.raise_in_ipagnn=True
  # config.optimizer = 'sgd'
  # config.hidden_size = 128
  # config.span_encoding_method = 'max'
  # config.permissive_node_embeddings = False
  # config.transformer_emb_dim = 512
  # config.transformer_num_heads = 8
  # config.transformer_num_layers = 6
  # config.transformer_qkv_dim = 512
  # config.transformer_mlp_dim = 2048

  # config.restore_checkpoint_dir=(
  #     # '/mnt/runtime-error-problems-experiments/experiments/2021-11-08-ckpts-001/36/I2-h=128,s=sum,b=32,pne=F/top-checkpoints/checkpoint_89901'
  #     '/mnt/runtime-error-problems-experiments/experiments/2021-11-02-docstring/33/E1952,o=sgd,bs=32,lr=0.3,gc=0.5,hs=128,span=max,tdr=0.1,tadr=0,pe=False,T=default/checkpoints'
  # )
  # config.span_encoding_method = 'mean'

  config.multidevice = False
  config.batch_size = 32
  config.raise_in_ipagnn = True
  config.restore_checkpoint_dir = (
      '/mnt/runtime-error-problems-experiments/experiments/2021-09-24-pretrain-004-copy/6/'
      'I1466,o=sgd,bs=32,lr=0.3,gc=2,hs=256,span=max,'
      'tdr=0,tadr=0,pe=False,T=default/checkpoints/'
  )
  config.optimizer = 'sgd'
  config.hidden_size = 256
  config.span_encoding_method = 'max'
  config.permissive_node_embeddings = False
  config.transformer_emb_dim = 512
  config.transformer_num_heads = 8
  config.transformer_num_layers = 6
  config.transformer_qkv_dim = 512
  config.transformer_mlp_dim = 2048

  config.restore_checkpoint_dir = (
      '/mnt/runtime-error-problems-experiments/experiments/2021-10-11-finetune-006-copy/10/'
      'E122,o=sgd,bs=32,lr=0.3,gc=2,hs=256,span=max,'
      'tdr=0.1,tadr=0,pe=False,T=default/checkpoints'
  )
  config.span_encoding_method = 'max'
  return config


@dataclasses.dataclass
class VisualizationInfo:
  """Information for visualizing model predictions."""
  raw: process.RawRuntimeErrorProblem
  source: str
  target_error: str
  prediction_error: str
  error_contributions: jnp.array


def main(argv):
  del argv  # Unused.

  dataset_path = FLAGS.dataset_path
  config = FLAGS.config
  latex_template_path = FLAGS.latex_template_path
  config = set_config(config)

  jnp.set_printoptions(threshold=config.printoptions_threshold)
  info = info_lib.get_dataset_info(dataset_path, config)
  t = trainer.Trainer(config=config, info=info)

  split = 'valid'
  dataset = t.load_dataset(
      dataset_path=dataset_path, split=split, include_strings=True)

  # Initialize / Load the model state.
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  model = t.make_model(deterministic=False)
  state = t.create_train_state(init_rng, model)
  if config.restore_checkpoint_dir:
    state = checkpoints.restore_checkpoint(config.restore_checkpoint_dir, state)

  train_step = t.make_train_step()
  for batch in tfds.as_numpy(dataset):
    assert not config.multidevice
    # We do not allow multidevice in this script.
    # if config.multidevice:
    #   batch = common_utils.shard(batch)
    problem_ids = batch.pop('problem_id')
    submission_ids = batch.pop('submission_id')
    state, aux = train_step(state, batch)

    instruction_pointer = aux['instruction_pointer_orig']
    # instruction_pointer.shape: steps, batch_size, num_nodes
    instruction_pointer = jnp.transpose(instruction_pointer, [1, 0, 2])
    # instruction_pointer.shape: batch_size, steps, num_nodes
    exit_index = batch['exit_index']
    raise_index = exit_index + 1
    raise_decisions = aux['raise_decisions']
    # raise_decisions.shape: steps, batch_size, num_nodes, 2
    raise_decisions = jnp.transpose(raise_decisions, [1, 0, 2, 3])
    # raise_decisions.shape: batch_size, steps, num_nodes, 2
    contributions = get_raise_contribution_batch(instruction_pointer, raise_decisions, raise_index, batch['step_limit'])
    # contributions.shape: batch_size, num_nodes

    for index, (problem_id, submission_id, contribution) \
        in enumerate(zip(problem_ids, submission_ids, contributions)):
      problem_id = problem_id[0].decode('utf-8')
      submission_id = submission_id[0].decode('utf-8')
      python_path = codenet.get_python_path(problem_id, submission_id)
      r_index = int(raise_index[index])
      num_nodes = int(raise_index[index]) + 1
      target = int(batch['target'][index])
      target_error = error_kinds.to_error(target)
      prediction = int(jnp.argmax(aux['logits'][index]))
      prediction_error = error_kinds.to_error(prediction)
      step_limit = batch['step_limit'][index, 0]
      instruction_pointer_single = instruction_pointer[index]

      total_contribution = jnp.sum(contribution)
      actual_value = instruction_pointer[index, -1, r_index]
      max_contributor = int(jnp.argmax(contribution))
      max_contribution = contribution[max_contributor]

      # Not all submissions are in the copy of the dataset in gs://project-codenet-data.
      # So we only visualize those that are in the copy.
      if os.path.exists(python_path):
        found = True
        with open(python_path, 'r') as f:
          source = f.read()
        error_lineno = codenet.get_error_lineno(problem_id, submission_id)
        raw = process.make_rawruntimeerrorproblem(
            source, target,
            target_lineno=error_lineno, problem_id=problem_id, submission_id=submission_id)

        # Visualize the data.
        print('---')
        print(f'Problem: {problem_id} {submission_id} ({split})')
        print(f'Batch index: {index}')
        print(f'Target: {target} ({target_error})')
        print(f'Prediction: {prediction} ({prediction_error})')
        print()
        print(source.strip() + '\n')
        print_spans(raw)
        print(contribution[:num_nodes])
        print(f'Main contributor: Node {max_contributor} ({max_contribution})')
        print(f'Total contribution: {total_contribution} (Actual: {actual_value})')

        if error_lineno:
          nodes_at_error = process.get_nodes_at_lineno(raw, error_lineno)
          print(f'Error lineno: {error_lineno} (nodes {nodes_at_error})')
          print(source.split('\n')[error_lineno - 1])  # -1 for line index.

        visualization_info = VisualizationInfo(
            raw=raw,
            source=source.strip(),
            target_error=target_error,
            prediction_error=prediction_error,
            error_contributions=contribution[:num_nodes])

        show_latex_predictions(info=visualization_info, index=index)

        instruction_pointer_single_trim = instruction_pointer_single[:step_limit + 1, :num_nodes].T
        # instruction_pointer_single_trim.shape: num_nodes, timesteps
        image = metrics.instruction_pointer_to_image(instruction_pointer_single_trim)
        imageio.imwrite('viz-instruction-pointer.png', image, format='png')

        # Wait for the user to press enter, then continue visualizing.
        input()


def show_latex_predictions(info: VisualizationInfo, index: int):
  raw = info.raw
  spans = tuple(get_spans(raw))
  error_contributions = info.error_contributions

  latex_lines = []
  span_count = len(spans)
  error_contribution_count = info.error_contributions.shape[0]
  if span_count != error_contribution_count:
    print(
        f'Expected span count {span_count} to match error contribution count '
        f'{error_contribution_count}')

  for i, (span,
          error_contribution) in enumerate(zip(spans, error_contributions)):
    latex_lines.append(
        f'\code{{{i}}} & \code{{{span}}} & \code{{{error_contribution:0.2f}}}'
    )

  line_separator = '\\\\ \hdashline\n'
  latex_content = line_separator.join(latex_lines)
  print('latex_content')
  print(latex_content)


if __name__ == '__main__':
  app.run(main)
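To make the core computation above concrete, here is a small self-contained sketch (not part of this commit) of the arithmetic in get_raise_contribution, using toy values: 3 timesteps, 4 nodes, the raise node at index 3, and a step limit of 2. All numbers are illustrative.

import jax.numpy as jnp

instruction_pointer = jnp.array([      # steps x num_nodes: p(node executes at step)
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.8, 0.0, 0.2],
    [0.0, 0.0, 0.6, 0.4],
])
p_raise = jnp.array([                  # steps x num_nodes: p(raise | node executes)
    [0.1, 0.1, 0.1, 0.0],
    [0.1, 0.5, 0.1, 0.0],
    [0.1, 0.1, 0.9, 0.0],
])
raise_index = 3                        # contributions from the raise node itself are zeroed
step_limit = 2                         # only steps before the limit count

per_step = (p_raise * instruction_pointer).at[:, raise_index].set(0)
mask = jnp.arange(per_step.shape[0]) < step_limit
contribution = jnp.sum(jnp.where(mask[:, None], per_step, 0), axis=0)
print(contribution)                    # per-node mass routed to the raise node: [0.1, 0.4, 0.0, 0.0]

For running the script itself, the flags defined above suggest an invocation along the lines of python scripts/visualize_predictions.py --dataset_path=... --config=... from the repository root; the exact command is not part of the commit.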
