
Commit 518e52b (parent 71deb99)

Add a prediction visualization script.
Work in progress.

3 files changed: +81 -8 lines

core/data/data_io.py
10 additions, 3 deletions

@@ -92,8 +92,8 @@ def get_fake_input(batch_size, max_tokens, max_num_nodes, max_num_edges):
   }


-def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges):
-  return {
+def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=False):
+  shapes = {
       'tokens': [max_tokens],
       'edge_sources': [max_num_edges],
       'edge_dests': [max_num_edges],
@@ -111,6 +111,13 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges):
       'num_nodes': [1],
       'num_edges': [1],
   }
+  if include_strings:
+    shapes.update({
+        'problem_id': [1],
+        'submission_id': [1],
+    })
+
+  return shapes


 def make_filter(
@@ -186,4 +193,4 @@ def load_dataset(dataset_path=codenet_paths.DEFAULT_DATASET_PATH, split='train',
     return load_tfrecords_dataset(tfrecord_paths, include_strings=include_strings)
   else:
     tfrecord_path = codenet_paths.make_tfrecord_path(dataset_path, split)
-    return load_tfrecord_dataset(tfrecord_path, include_strings=include_strings)
+    return load_tfrecord_dataset(tfrecord_path, include_strings=include_strings)
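
The new shape entries matter because tf.data's padded_batch requires a padded shape for every feature in the example dict; once the loaded examples carry 'problem_id' and 'submission_id', batching fails unless get_padded_shapes accounts for them. A self-contained sketch of the pattern (the toy generator and shapes below are illustrative, not from this repo):

# Standalone sketch (not from this repo) of why string features need entries
# in padded_shapes: padded_batch pads every feature in the example dict, so
# when 'problem_id' and 'submission_id' ride along, they need shapes too.
import tensorflow as tf

def gen():
  yield {'tokens': [1, 2, 3], 'problem_id': ['p00001'], 'submission_id': ['s001']}
  yield {'tokens': [4, 5], 'problem_id': ['p00002'], 'submission_id': ['s002']}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'tokens': tf.TensorSpec([None], tf.int32),
        'problem_id': tf.TensorSpec([1], tf.string),
        'submission_id': tf.TensorSpec([1], tf.string),
    })

# Mirrors get_padded_shapes(..., include_strings=True): variable-length
# features pad to a fixed length, while the string ids keep shape [1].
padded_shapes = {'tokens': [8], 'problem_id': [1], 'submission_id': [1]}
for batch in ds.padded_batch(2, padded_shapes=padded_shapes):
  print(batch['tokens'].shape)        # (2, 8)
  print(batch['problem_id'].numpy())  # [[b'p00001'] [b'p00002']]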

core/lib/trainer.py
9 additions, 5 deletions

@@ -28,12 +28,12 @@
 from core.lib import metrics
 from core.lib import models
 from core.lib import optimizer_lib
-from core.lib.metrics import EvaluationMetric


 DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH

 Config = ml_collections.ConfigDict
+EvaluationMetric = metrics.EvaluationMetric


 class TrainState(train_state.TrainState):
@@ -47,7 +47,8 @@ class Trainer:
   info: Any

   def load_dataset(
-      self, dataset_path=DEFAULT_DATASET_PATH, split='train', epochs=None
+      self, dataset_path=DEFAULT_DATASET_PATH, split='train', epochs=None,
+      include_strings=False,
   ):
     config = self.config
     batch_size = config.batch_size
@@ -57,7 +58,9 @@ def load_dataset(
     allowlist = config.allowlist

     padded_shapes = data_io.get_padded_shapes(
-        config.max_tokens, config.max_num_nodes, config.max_num_edges)
+        config.max_tokens, config.max_num_nodes, config.max_num_edges, include_strings=include_strings)
+    print('padded_shapes')
+    print(padded_shapes)
     if allowlist == 'TIER1_ERROR_IDS':
       allowlist = error_kinds.TIER1_ERROR_IDS
     filter_fn = data_io.make_filter(
@@ -68,7 +71,7 @@ def load_dataset(
       # Prepare a dataset with a single repeating batch.
       split = split[:-len('-batch')]
       return (
-          data_io.load_dataset(dataset_path, split=split)
+          data_io.load_dataset(dataset_path, split=split, include_strings=include_strings)
           .filter(filter_fn)
           .take(batch_size)
           .repeat(epochs)
@@ -77,7 +80,7 @@

     # Return the requested dataset.
     return (
-        data_io.load_dataset(dataset_path, split=split)
+        data_io.load_dataset(dataset_path, split=split, include_strings=include_strings)
         .filter(filter_fn)
         .repeat(epochs)
         .shuffle(1000)
@@ -303,6 +306,7 @@ def run_train(self, dataset_path=DEFAULT_DATASET_PATH, split='train', steps=None
     train_predictions = []
     train_targets = []
     train_losses = []
+    print('Starting training')
     for step_index, batch in itertools.islice(enumerate(tfds.as_numpy(dataset)), steps):
       step = state.step
       if config.multidevice:
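
Since include_strings defaults to False, existing training entry points behave exactly as before; only callers that want the ids opt in. A minimal sketch of the intended call pattern, assuming a config and info built as elsewhere in the repo (neither is constructed here, so this fragment is illustrative rather than runnable as-is):

# Hedged usage sketch; `config` and `info` are assumed to exist, built the
# way the training scripts in this repo build them.
from core.lib import trainer

t = trainer.Trainer(config=config, info=info)
train_ds = t.load_dataset(split='train')                        # numeric features only
debug_ds = t.load_dataset(split='train', include_strings=True)  # also carries the ids
# Batches from debug_ds additionally include 'problem_id' and 'submission_id';
# they must be popped before a batch reaches a jitted train step, since JAX
# cannot trace string arrays (the new visualize_predictions.py does this).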

scripts/visualize_predictions.py
62 additions, 0 deletions (new file)

@@ -0,0 +1,62 @@
+"""Visualize model predictions."""
+
+from absl import app
+from absl import flags
+
+from flax.training import common_utils
+import jax
+import jax.numpy as jnp
+from ml_collections.config_flags import config_flags
+import tensorflow_datasets as tfds
+
+from core.data import codenet_paths
+from core.data import info as info_lib
+from core.lib import trainer
+
+DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH
+DEFAULT_CONFIG_PATH = codenet_paths.DEFAULT_CONFIG_PATH
+
+
+flags.DEFINE_string('dataset_path', DEFAULT_DATASET_PATH, 'Dataset path.')
+config_flags.DEFINE_config_file(
+    name='config', default=DEFAULT_CONFIG_PATH, help_string='Config file.'
+)
+FLAGS = flags.FLAGS
+
+
+def main(argv):
+  del argv  # Unused.
+
+  dataset_path = FLAGS.dataset_path
+  config = FLAGS.config
+  jnp.set_printoptions(threshold=config.printoptions_threshold)
+  info = info_lib.get_dataset_info(dataset_path)
+  t = trainer.Trainer(config=config, info=info)
+
+  dataset = t.load_dataset(
+      dataset_path=dataset_path, split='train', include_strings=True)
+
+  # for i, example in enumerate(dataset):
+  #   print('example', i)
+  #   print(example)
+  #   break
+
+  rng = jax.random.PRNGKey(0)
+  rng, init_rng = jax.random.split(rng)
+  model = t.make_model(deterministic=False)
+
+  state = t.create_train_state(init_rng, model)
+
+  train_step = t.make_train_step()
+  for batch in tfds.as_numpy(dataset):
+    if config.multidevice:
+      batch = common_utils.shard(batch)
+    problem_id = batch.pop('problem_id')
+    submission_id = batch.pop('submission_id')
+    state, aux = train_step(state, batch)
+    print(aux.keys())
+    print(aux)
+
+
+if __name__ == '__main__':
+  app.run(main)
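
Given the absl and config_flags setup above, the script would presumably be launched as something like python -m scripts.visualize_predictions --config=path/to/config.py --dataset_path=/path/to/dataset (the exact invocation is not shown in the commit). The popped ids come back as byte strings carrying the [1] dimension from get_padded_shapes, so a follow-up step that pairs them with per-example outputs might look like this hypothetical helper (not part of the commit):

# Hypothetical helper (not in the commit) for turning the popped id tensors
# into readable (problem, submission) pairs alongside per-example predictions.
import numpy as np

def format_ids(problem_id, submission_id):
  # Each id arrives as bytes with a trailing [1] dim from get_padded_shapes;
  # reshape(-1) also flattens the extra leading device dim when sharded.
  problems = [p.decode('utf-8') for p in np.asarray(problem_id).reshape(-1)]
  submissions = [s.decode('utf-8') for s in np.asarray(submission_id).reshape(-1)]
  return list(zip(problems, submissions))

# e.g. inside the batch loop, after the pops:
#   for problem, submission in format_ids(problem_id, submission_id):
#     print(problem, submission)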
