This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit cbd7156

Add a prediction visualization script.
Work in-progress.

1 parent b8d7f31 commit cbd7156

File tree

3 files changed: +292 −1 lines changed

core/data/data_io.py

Lines changed: 1 addition & 1 deletion

@@ -147,7 +147,7 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=
       'problem_id': [1],
       'submission_id': [1],
   })
-
+
   return shapes

File renamed without changes.

scripts/visualize_predictions.py

Lines changed: 291 additions & 0 deletions

@@ -0,0 +1,291 @@
"""Visualize model predictions."""

import dataclasses
import os

from absl import app
from absl import flags

from flax.training import checkpoints
from flax.training import common_utils
import imageio
import jax
import jax.numpy as jnp
from ml_collections.config_flags import config_flags
import tensorflow_datasets as tfds

from core.data import codenet
from core.data import codenet_paths
from core.data import error_kinds
from core.data import info as info_lib
from core.data import process
from core.lib import metrics
from core.lib import trainer

DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH
DEFAULT_CONFIG_PATH = codenet_paths.DEFAULT_CONFIG_PATH


flags.DEFINE_string('dataset_path', DEFAULT_DATASET_PATH, 'Dataset path.')
flags.DEFINE_string('latex_template_path', 'example_figure_template.tex',
                    'LaTeX template path.')
config_flags.DEFINE_config_file(
    name='config', default=DEFAULT_CONFIG_PATH, help_string='Config file.'
)
FLAGS = flags.FLAGS


def get_raise_contribution_at_step(instruction_pointer, raise_decisions, raise_index):
  # instruction_pointer.shape: num_nodes
  # raise_decisions.shape: num_nodes, 2
  # raise_index.shape: scalar.
  p_raise = raise_decisions[:, 0]
  raise_contribution = p_raise * instruction_pointer
  # raise_contribution.shape: num_nodes
  raise_contribution = raise_contribution.at[raise_index].set(0)
  return raise_contribution
get_raise_contribution_at_steps = jax.vmap(get_raise_contribution_at_step, in_axes=(0, 0, None))


def get_raise_contribution(instruction_pointer, raise_decisions, raise_index, step_limit):
  # instruction_pointer.shape: steps, num_nodes
  # raise_decisions.shape: steps, num_nodes, 2
  # raise_index.shape: scalar.
  # step_limit.shape: scalar.
  raise_contributions = get_raise_contribution_at_steps(
      instruction_pointer, raise_decisions, raise_index)
  # raise_contributions.shape: steps, num_nodes
  mask = jnp.arange(instruction_pointer.shape[0]) < step_limit
  # mask.shape: steps
  raise_contributions = jnp.where(mask[:, None], raise_contributions, 0)
  raise_contribution = jnp.sum(raise_contributions, axis=0)
  # raise_contribution.shape: num_nodes
  return raise_contribution
get_raise_contribution_batch = jax.vmap(get_raise_contribution)


def print_spans(raw):
  span_starts = raw.node_span_starts
  span_ends = raw.node_span_ends
  for i, (span_start, span_end) in enumerate(zip(span_starts, span_ends)):
    print(f'Span {i}: {raw.source[span_start:span_end]}')


def get_spans(raw):
  span_starts = raw.node_span_starts
  span_ends = raw.node_span_ends
  for i, (span_start, span_end) in enumerate(zip(span_starts, span_ends)):
    yield raw.source[span_start:span_end]


def set_config(config):
  """This function is hard-coded to load a particular checkpoint.

  It also sets the model part of the config to match the config of that checkpoint.
  Everything related to parameter construction must match.
  """
  config.multidevice = False
  config.batch_size = 32
  config.raise_in_ipagnn = True
  config.optimizer = 'sgd'
  config.hidden_size = 128
  config.span_encoding_method = 'max'
  config.permissive_node_embeddings = False
  config.transformer_emb_dim = 512
  config.transformer_num_heads = 8
  config.transformer_num_layers = 6
  config.transformer_qkv_dim = 512
  config.transformer_mlp_dim = 2048

  # config.restore_checkpoint_dir=(
  #     # '/mnt/runtime-error-problems-experiments/experiments/2021-11-08-ckpts-001/36/I2-h=128,s=sum,b=32,pne=F/top-checkpoints/checkpoint_89901'
  #     '/mnt/runtime-error-problems-experiments/experiments/2021-11-02-docstring/33/E1952,o=sgd,bs=32,lr=0.3,gc=0.5,hs=128,span=max,tdr=0.1,tadr=0,pe=False,T=default/checkpoints'
  # )
  # config.span_encoding_method = 'mean'
  return config


@dataclasses.dataclass
class VisualizationInfo:
  """Information for visualizing model predictions."""
  raw: process.RawRuntimeErrorProblem
  target_error: str
  prediction_error: str
  error_contributions: jnp.array


def main(argv):
  del argv  # Unused.

  dataset_path = FLAGS.dataset_path
  config = FLAGS.config
  latex_template_path = FLAGS.latex_template_path
  config = set_config(config)

  jnp.set_printoptions(threshold=config.printoptions_threshold)
  info = info_lib.get_dataset_info(dataset_path, config)
  t = trainer.Trainer(config=config, info=info)

  split = 'valid'
  dataset = t.load_dataset(
      dataset_path=dataset_path, split=split, include_strings=True)

  # Initialize / Load the model state.
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  model = t.make_model(deterministic=False)
  state = t.create_train_state(init_rng, model)
  if config.restore_checkpoint_dir:
    state = checkpoints.restore_checkpoint(config.restore_checkpoint_dir, state)

  train_step = t.make_train_step()
  for batch in tfds.as_numpy(dataset):
    assert not config.multidevice
    # We do not allow multidevice in this script.
    # if config.multidevice:
    #   batch = common_utils.shard(batch)
    problem_ids = batch.pop('problem_id')
    submission_ids = batch.pop('submission_id')
    state, aux = train_step(state, batch)

    instruction_pointer = aux['instruction_pointer_orig']
    # instruction_pointer.shape: steps, batch_size, num_nodes
    instruction_pointer = jnp.transpose(instruction_pointer, [1, 0, 2])
    # instruction_pointer.shape: batch_size, steps, num_nodes
    exit_index = batch['exit_index']
    raise_index = exit_index + 1
    raise_decisions = aux['raise_decisions']
    # raise_decisions.shape: steps, batch_size, num_nodes, 2
    raise_decisions = jnp.transpose(raise_decisions, [1, 0, 2, 3])
    # raise_decisions.shape: batch_size, steps, num_nodes, 2
    contributions = get_raise_contribution_batch(
        instruction_pointer, raise_decisions, raise_index, batch['step_limit'])
    # contributions.shape: batch_size, num_nodes

    for index, (problem_id, submission_id, contribution) \
        in enumerate(zip(problem_ids, submission_ids, contributions)):
      problem_id = problem_id[0].decode('utf-8')
      submission_id = submission_id[0].decode('utf-8')
      python_path = codenet.get_python_path(problem_id, submission_id)
      r_index = int(raise_index[index])
      num_nodes = int(raise_index[index]) + 1
      target = int(batch['target'][index])
      target_error = error_kinds.to_error(target)
      prediction = int(jnp.argmax(aux['logits'][index]))
      prediction_error = error_kinds.to_error(prediction)
      step_limit = batch['step_limit'][index]
      instruction_pointer_single = instruction_pointer[index]

      total_contribution = jnp.sum(contribution)
      actual_value = instruction_pointer[index, -1, r_index]
      max_contributor = int(jnp.argmax(contribution))
      max_contribution = contribution[max_contributor]

      # Not all submissions are in the copy of the dataset in gs://project-codenet-data.
      # So we only visualize those that are in the copy.
      if os.path.exists(python_path):
        found = True
        with open(python_path, 'r') as f:
          source = f.read()
        error_lineno = codenet.get_error_lineno(problem_id, submission_id)
        raw = process.make_rawruntimeerrorproblem(
            source, target,
            target_lineno=error_lineno, problem_id=problem_id, submission_id=submission_id)

        # Visualize the data.
        print('---')
        print(f'Problem: {problem_id} {submission_id} ({split})')
        print(f'Batch index: {index}')
        print(f'Target: {target} ({target_error})')
        print(f'Prediction: {prediction} ({prediction_error})')
        print()
        print(source.strip() + '\n')
        print_spans(raw)
        print(contribution[:num_nodes])
        print(f'Main contributor: Node {max_contributor} ({max_contribution})')
        print(f'Total contribution: {total_contribution} (Actual: {actual_value})')

        instruction_pointer_single_trim = instruction_pointer_single[:step_limit + 1, :num_nodes].T
        # instruction_pointer_single_trim.shape: num_nodes, timesteps
        image = metrics.instruction_pointer_to_image(instruction_pointer_single_trim)
        imageio.imwrite('viz-instruction-pointer.png', image, format='png')
        with open('viz-source.txt', 'w') as f:
          f.write(source)

        if error_lineno:
          nodes_at_error = process.get_nodes_at_lineno(raw, error_lineno)
          print(f'Error lineno: {error_lineno} (nodes {nodes_at_error})')
          print(source.split('\n')[error_lineno - 1])  # -1 for line index.

        visualization_info = VisualizationInfo(
            raw=raw,
            target_error=target_error,
            prediction_error=prediction_error,
            error_contributions=contribution[:num_nodes])

        show_latex_predictions(info=visualization_info, index=index)

        # Wait for the user to press enter, then continue visualizing.
        input()


def show_latex_predictions(info: VisualizationInfo, index: int):
  raw = info.raw
  spans = tuple(get_spans(raw))
  error_contributions = info.error_contributions

  latex_lines = []
  span_count = len(spans)
  error_contribution_count = info.error_contributions.shape[0]
  if span_count != error_contribution_count:
    print(
        f'Expected span count {span_count} to match error contribution count '
        f'{error_contribution_count}')
    # raise AssertionError(
    #     f'Expected span count {span_count} to match error contribution count '
    #     f'{error_contribution_count}')

  # Always three more
  for i, (span,
          error_contribution) in enumerate(zip(spans, error_contributions)):
    latex_lines.append(
        f'\code{{{i}}} & \code{{{span}}} & \code{{{error_contribution}}}'
    )

  line_separator = '\\ \hdashline\n'
  latex_content = line_separator.join(latex_lines)
  print('latex_content')
  print(latex_content)

  # latex_template = '''\
  # \begin{figure}%[ht]
  # % \hspace{0pt}
  # \centering
  # \resizebox{\textwidth}{!}
  # {
  # \begin{tabular}{cl|cccc|ccc}
  # \toprule
  # $n$ & Source & \multicolumn{4}{c}{Tokenization ($x_n$)} & Exception provenance & $\incomingneighborsX(n)$ & $\outgoingneighborsX(n)$ \\
  # \midrule
  # \code{0} & \code{v0 = 23} & \code{0} & \code{=} & & \code{v0} & \code{23} &
  # $\emptyset$ & $\{1\}$
  # \\ \hdashline
  # \code{1} & \code{v1 = 6} & \code{0} & \code{=} & & \code{v1} & \code{~6} & $\{0\}$ & $\{2\}$\\ \hdashline
  # \code{2} & \code{while v1 > 0:} & \code{0} & \code{while >} & & \code{v1} & \code{~0} & $\{1, 7\}$ & $\{3, 8\}$\\ \hdashline
  # \code{3} & \code{~~v1 -= 1} & \code{1} & \code{-=} & & \code{v1} & \code{~1} & $\{2\}$ & $\{4\}$\\ \hdashline
  # \code{4} & \code{~~if v0 \% 10 <= 3:} & \code{1} & \code{if <= \%} & & \code{v0} & \code{~3} & $\{3\}$ & $\{5\}$\\ \hdashline
  # \code{5} & \code{~~~~v0 += 4} & \code{2} & \code{+=} & & \code{v0} & \code{~4} & $\{4\}$ & $\{6\}$\\ \hdashline
  # \code{6} & \code{~~~~v0 *= 6} & \code{2} & \code{*=} & & \code{v0} & \code{~6} & $\{5\}$ & $\{7\}$ \\ \hdashline
  # \code{7} & \code{~~v0 -= 1} & \code{1} & \code{-=} & & \code{v0} & \code{~1} & $\{4, 6\}$ & $\{2\}$\\ \hdashline
  # \code{8} & \code{<exit>} & \code{-} & \code{-} & & \code{-} & \code{~-} & $\{2, 8\}$ & $\{8\}$\\
  # \bottomrule
  # \end{tabular}
  # }
  # \caption{
  # \textbf{Program representation.} Each line of a program is represented by a 4-tuple tokenization containing that line's (indentation level, operation, variable, operand), and is associated with a node in the program's statement-level control flow graph.
  # }
  # \label{fig:program-representations}
  # \end{figure}
  # '''

if __name__ == '__main__':
  app.run(main)
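
Note: the raise-contribution helpers above sum, for each node, the probability mass it routes toward the raise node across the unrolled steps, masking out steps beyond each example's step limit. The following toy sketch is not part of this commit; it only illustrates the expected input shapes and a hypothetical invocation, and it assumes the module is importable as scripts.visualize_predictions from the repository root.

# Hypothetical invocation (paths are placeholders):
#   python scripts/visualize_predictions.py --config=path/to/config.py \
#       --dataset_path=/path/to/dataset
import jax.numpy as jnp

from scripts.visualize_predictions import get_raise_contribution_batch  # assumed import path

batch_size, steps, num_nodes = 2, 4, 5
# Uniform instruction-pointer mass over the nodes at every step.
instruction_pointer = jnp.full((batch_size, steps, num_nodes), 1.0 / num_nodes)
# Column 0 of the last axis holds each node's probability of raising at that step.
p_raise = jnp.full((batch_size, steps, num_nodes), 0.1)
raise_decisions = jnp.stack([p_raise, 1.0 - p_raise], axis=-1)
raise_index = jnp.array([4, 4])  # the raise node's own contribution is zeroed out
step_limit = jnp.array([4, 2])   # steps past the limit are masked out

contributions = get_raise_contribution_batch(
    instruction_pointer, raise_decisions, raise_index, step_limit)
print(contributions.shape)  # (2, 5): per-example, per-node contribution to the raise node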
