|
| 1 | +"""Visualize model predictions.""" |
| 2 | + |
| 3 | +import dataclasses |
| 4 | +import os |
| 5 | + |
| 6 | +from absl import app |
| 7 | +from absl import flags |
| 8 | + |
| 9 | +from flax.training import checkpoints |
| 10 | +from flax.training import common_utils |
| 11 | +import imageio |
| 12 | +import jax |
| 13 | +import jax.numpy as jnp |
| 14 | +from ml_collections.config_flags import config_flags |
| 15 | +import tensorflow_datasets as tfds |
| 16 | + |
| 17 | +from core.data import codenet |
| 18 | +from core.data import codenet_paths |
| 19 | +from core.data import error_kinds |
| 20 | +from core.data import info as info_lib |
| 21 | +from core.data import process |
| 22 | +from core.lib import metrics |
| 23 | +from core.lib import trainer |
| 24 | + |
# Default locations are taken from the shared CodeNet path settings.
DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH
DEFAULT_CONFIG_PATH = codenet_paths.DEFAULT_CONFIG_PATH


# Command-line flags for this script.
flags.DEFINE_string(
    name='dataset_path',
    default=DEFAULT_DATASET_PATH,
    help='Dataset path.')
flags.DEFINE_string(
    name='latex_template_path',
    default='example_figure_template.tex',
    help='LaTeX template path.')
config_flags.DEFINE_config_file(
    'config', DEFAULT_CONFIG_PATH, help_string='Config file.')
FLAGS = flags.FLAGS
| 36 | + |
| 37 | + |
def get_raise_contribution_at_step(instruction_pointer, raise_decisions, raise_index):
  """Compute each node's raise contribution for a single step.

  Args:
    instruction_pointer: Array of shape (num_nodes,).
    raise_decisions: Array of shape (num_nodes, 2). Column 0 is read as the
      per-node raise probability.
    raise_index: Scalar index of the raise node.

  Returns:
    Array of shape (num_nodes,): raise probability times instruction pointer
    for every node, with the entry at `raise_index` forced to 0.
  """
  weighted = raise_decisions[:, 0] * instruction_pointer
  # weighted.shape: num_nodes
  return weighted.at[raise_index].set(0)


# Same computation mapped over a leading steps axis of the first two
# arguments; raise_index is shared across steps.
get_raise_contribution_at_steps = jax.vmap(
    get_raise_contribution_at_step, in_axes=(0, 0, None))
| 48 | + |
| 49 | + |
def get_raise_contribution(instruction_pointer, raise_decisions, raise_index, step_limit):
  """Sum per-step raise contributions over the steps before `step_limit`.

  Args:
    instruction_pointer: Array of shape (steps, num_nodes).
    raise_decisions: Array of shape (steps, num_nodes, 2).
    raise_index: Scalar index of the raise node.
    step_limit: Scalar; steps with index >= step_limit contribute nothing.

  Returns:
    Array of shape (num_nodes,) with the summed contribution of each node.
  """
  num_steps = instruction_pointer.shape[0]
  per_step = get_raise_contribution_at_steps(
      instruction_pointer, raise_decisions, raise_index)
  # per_step.shape: steps, num_nodes
  within_limit = jnp.arange(num_steps) < step_limit
  # within_limit.shape: steps
  per_step = jnp.where(within_limit[:, None], per_step, 0)
  # Result shape: num_nodes
  return jnp.sum(per_step, axis=0)


# Batched variant: every argument carries a leading batch axis.
get_raise_contribution_batch = jax.vmap(get_raise_contribution)
| 65 | + |
| 66 | + |
def print_spans(raw):
  """Print the source text of every node span in `raw`, one line per span.

  Args:
    raw: Object exposing `source` (str) plus parallel sequences
      `node_span_starts` and `node_span_ends` of slice bounds into `source`.
  """
  bounds = zip(raw.node_span_starts, raw.node_span_ends)
  for node_index, (begin, end) in enumerate(bounds):
    snippet = raw.source[begin:end]
    print(f'Span {node_index}: {snippet}')
| 72 | + |
| 73 | + |
def get_spans(raw):
  """Yield the source text of each node span in `raw`, in node order.

  Args:
    raw: Object exposing `source` (str) plus parallel sequences
      `node_span_starts` and `node_span_ends` of slice bounds into `source`.

  Yields:
    The substring `raw.source[start:end]` for each (start, end) pair.
  """
  # The span index is not needed here, so iterate the pairs directly.
  for span_start, span_end in zip(raw.node_span_starts, raw.node_span_ends):
    yield raw.source[span_start:span_end]
| 79 | + |
| 80 | + |
def set_config(config):
  """Point `config` at a hard-coded checkpoint and matching model settings.

  The checkpoint directory is hard-coded. The model-related fields are set to
  the values the checkpoint was trained with; everything related to parameter
  construction must match the checkpoint.

  Args:
    config: Config object supporting attribute assignment. It is modified in
      place.

  Returns:
    The same `config` object, after modification.
  """
  # Reference only (not applied): settings for an alternative hs=128 checkpoint.
  # config.multidevice=False
  # config.batch_size=32
  # config.raise_in_ipagnn=True
  # config.optimizer = 'sgd'
  # config.hidden_size = 128
  # config.span_encoding_method = 'max'
  # config.permissive_node_embeddings = False
  # config.transformer_emb_dim = 512
  # config.transformer_num_heads = 8
  # config.transformer_num_layers = 6
  # config.transformer_qkv_dim = 512
  # config.transformer_mlp_dim = 2048
  # config.restore_checkpoint_dir=(
  #   # '/mnt/runtime-error-problems-experiments/experiments/2021-11-08-ckpts-001/36/I2-h=128,s=sum,b=32,pne=F/top-checkpoints/checkpoint_89901'
  #   '/mnt/runtime-error-problems-experiments/experiments/2021-11-02-docstring/33/E1952,o=sgd,bs=32,lr=0.3,gc=0.5,hs=128,span=max,tdr=0.1,tadr=0,pe=False,T=default/checkpoints'
  # )
  # config.span_encoding_method = 'mean'

  # Reference only (not applied): this pretrain checkpoint used to be assigned
  # to restore_checkpoint_dir here, but was always overwritten by the finetune
  # checkpoint below before being read.
  #   '/mnt/runtime-error-problems-experiments/experiments/2021-09-24-pretrain-004-copy/6/'
  #   'I1466,o=sgd,bs=32,lr=0.3,gc=2,hs=256,span=max,'
  #   'tdr=0,tadr=0,pe=False,T=default/checkpoints/'

  # Run settings.
  config.multidevice = False
  config.batch_size = 32
  config.raise_in_ipagnn = True

  # Model settings for the checkpoint selected below.
  config.optimizer = 'sgd'
  config.hidden_size = 256
  config.span_encoding_method = 'max'
  config.permissive_node_embeddings = False
  config.transformer_emb_dim = 512
  config.transformer_num_heads = 8
  config.transformer_num_layers = 6
  config.transformer_qkv_dim = 512
  config.transformer_mlp_dim = 2048

  # The checkpoint that is actually restored.
  config.restore_checkpoint_dir = (
      '/mnt/runtime-error-problems-experiments/experiments/2021-10-11-finetune-006-copy/10/'
      'E122,o=sgd,bs=32,lr=0.3,gc=2,hs=256,span=max,'
      'tdr=0.1,tadr=0,pe=False,T=default/checkpoints'
  )
  return config
| 132 | + |
| 133 | + |
@dataclasses.dataclass
class VisualizationInfo:
  """Information for visualizing model predictions.

  Attributes:
    raw: The raw problem built from the submission's source.
    source: The submission's source code.
    target_error: The target error for the submission.
    prediction_error: The error predicted by the model.
    error_contributions: Per-node contributions to the raise prediction;
      indexed along axis 0 by node.
  """
  raw: process.RawRuntimeErrorProblem
  source: str
  target_error: str
  prediction_error: str
  # `jnp.ndarray` is the array type; `jnp.array` is the constructor function.
  error_contributions: jnp.ndarray
| 142 | + |
| 143 | + |
def main(argv):
  """Interactively visualize model predictions on the 'valid' split.

  Builds a trainer from the flag-provided config (after `set_config` overrides
  it), optionally restores a checkpoint, and runs the train step on each batch
  to obtain the auxiliary outputs. For every example whose source file exists
  locally, prints the source, spans, per-node raise contributions and the
  LaTeX rows, writes the instruction pointer image to
  'viz-instruction-pointer.png', and waits for Enter before continuing.

  Args:
    argv: Unused positional command-line arguments.
  """
  del argv  # Unused.

  dataset_path = FLAGS.dataset_path
  config = FLAGS.config
  # NOTE(review): latex_template_path is read but not used later in this function.
  latex_template_path = FLAGS.latex_template_path
  config = set_config(config)

  jnp.set_printoptions(threshold=config.printoptions_threshold)
  info = info_lib.get_dataset_info(dataset_path, config)
  t = trainer.Trainer(config=config, info=info)

  split = 'valid'
  dataset = t.load_dataset(
      dataset_path=dataset_path, split=split, include_strings=True)

  # Initialize / Load the model state.
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  model = t.make_model(deterministic=False)
  state = t.create_train_state(init_rng, model)
  if config.restore_checkpoint_dir:
    state = checkpoints.restore_checkpoint(config.restore_checkpoint_dir, state)

  train_step = t.make_train_step()
  for batch in tfds.as_numpy(dataset):
    assert not config.multidevice
    # We do not allow multidevice in this script.
    # if config.multidevice:
    #   batch = common_utils.shard(batch)
    # The id fields are removed from the batch before it is passed to the step.
    problem_ids = batch.pop('problem_id')
    submission_ids = batch.pop('submission_id')
    # NOTE(review): this is the *train* step and `state` is rebound to its
    # result; presumably parameters are updated between batches — confirm this
    # is intended for a visualization script.
    state, aux = train_step(state, batch)

    instruction_pointer = aux['instruction_pointer_orig']
    # instruction_pointer.shape: steps, batch_size, num_nodes
    instruction_pointer = jnp.transpose(instruction_pointer, [1, 0, 2])
    # instruction_pointer.shape: batch_size, steps, num_nodes
    exit_index = batch['exit_index']
    # The raise node is taken to be the node right after the exit node.
    raise_index = exit_index + 1
    raise_decisions = aux['raise_decisions']
    # raise_decisions.shape: steps, batch_size, num_nodes, 2
    raise_decisions = jnp.transpose(raise_decisions, [1, 0, 2, 3])
    # raise_decisions.shape: batch_size, steps, num_nodes, 2
    contributions = get_raise_contribution_batch(instruction_pointer, raise_decisions, raise_index, batch['step_limit'])
    # contributions.shape: batch_size, num_nodes

    for index, (problem_id, submission_id, contribution) \
        in enumerate(zip(problem_ids, submission_ids, contributions)):
      # Each id is stored as a length-1 sequence of bytes.
      problem_id = problem_id[0].decode('utf-8')
      submission_id = submission_id[0].decode('utf-8')
      python_path = codenet.get_python_path(problem_id, submission_id)
      r_index = int(raise_index[index])
      # Same value as r_index + 1; used below to trim per-node arrays.
      num_nodes = int(raise_index[index]) + 1
      target = int(batch['target'][index])
      target_error = error_kinds.to_error(target)
      prediction = int(jnp.argmax(aux['logits'][index]))
      prediction_error = error_kinds.to_error(prediction)
      step_limit = batch['step_limit'][index, 0]
      instruction_pointer_single = instruction_pointer[index]

      total_contribution = jnp.sum(contribution)
      # Instruction pointer value at the raise node on the last step; printed
      # next to the summed contributions for comparison.
      actual_value = instruction_pointer[index, -1, r_index]
      max_contributor = int(jnp.argmax(contribution))
      max_contribution = contribution[max_contributor]

      # Not all submissions are in the copy of the dataset in gs://project-codenet-data.
      # So we only visualize those that are in the copy.
      # NOTE(review): everything through input() is assumed to be nested under
      # this check (it depends on `source` / `raw`) — confirm against the
      # original indentation.
      if os.path.exists(python_path):
        # NOTE(review): `found` is not read anywhere in this function.
        found = True
        with open(python_path, 'r') as f:
          source = f.read()
        error_lineno = codenet.get_error_lineno(problem_id, submission_id)
        raw = process.make_rawruntimeerrorproblem(
            source, target,
            target_lineno=error_lineno, problem_id=problem_id, submission_id=submission_id)

        # Visualize the data.
        print('---')
        print(f'Problem: {problem_id} {submission_id} ({split})')
        print(f'Batch index: {index}')
        print(f'Target: {target} ({target_error})')
        print(f'Prediction: {prediction} ({prediction_error})')
        print()
        print(source.strip() + '\n')
        print_spans(raw)
        print(contribution[:num_nodes])
        print(f'Main contributor: Node {max_contributor} ({max_contribution})')
        print(f'Total contribution: {total_contribution} (Actual: {actual_value})')

        if error_lineno:
          nodes_at_error = process.get_nodes_at_lineno(raw, error_lineno)
          print(f'Error lineno: {error_lineno} (nodes {nodes_at_error})')
          print(source.split('\n')[error_lineno - 1])  # -1 for line index.

        visualization_info = VisualizationInfo(
            raw=raw,
            source=source.strip(),
            target_error=target_error,
            prediction_error=prediction_error,
            error_contributions=contribution[:num_nodes])

        show_latex_predictions(info=visualization_info, index=index)

        # Keep steps 0..step_limit and the first num_nodes nodes, then
        # transpose so that nodes are rows.
        instruction_pointer_single_trim = instruction_pointer_single[:step_limit + 1, :num_nodes].T
        # instruction_pointer_single_trim.shape: num_nodes, timesteps
        image = metrics.instruction_pointer_to_image(instruction_pointer_single_trim)
        # The same file name is used for every example, so it is overwritten.
        imageio.imwrite('viz-instruction-pointer.png', image, format='png')

        # Wait for the user to press enter, then continue visualizing.
        input()
| 255 | + |
| 256 | + |
def show_latex_predictions(info: VisualizationInfo, index: int):
  """Print LaTeX table rows with one row per node span.

  Each row has the form `\\code{i} & \\code{span} & \\code{contribution}`, and
  rows are joined with `\\\\ \\hdashline`. If the number of spans differs from
  the number of error contributions, a message is printed and only the
  overlapping prefix is rendered.

  Args:
    info: Visualization data; `info.raw` supplies the spans and
      `info.error_contributions` the per-node values.
    index: Batch index of the example. Currently unused.
  """
  del index  # Unused; kept so existing callers keep working.

  spans = tuple(get_spans(info.raw))
  error_contributions = info.error_contributions

  span_count = len(spans)
  error_contribution_count = error_contributions.shape[0]
  if span_count != error_contribution_count:
    print(
        f'Expected span count {span_count} to match error contribution count '
        f'{error_contribution_count}')

  # Backslashes are escaped explicitly: '\c' and '\h' are not valid Python
  # escape sequences and trigger a warning on modern Python. The emitted text
  # is unchanged.
  latex_lines = [
      f'\\code{{{i}}} & \\code{{{span}}} & \\code{{{error_contribution:0.2f}}}'
      for i, (span, error_contribution)
      in enumerate(zip(spans, error_contributions))
  ]

  line_separator = '\\\\ \\hdashline\n'
  latex_content = line_separator.join(latex_lines)
  print('latex_content')
  print(latex_content)
| 280 | + |
# When executed as a script, run main() through absl's app runner.
if __name__ == '__main__':
  app.run(main)
0 commit comments