This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit 8cbb8f1

Add a prediction visualization script.
Work in progress.
1 parent: 71deb99

3 files changed: 78 additions, 8 deletions

core/data/data_io.py (10 additions, 3 deletions)

@@ -92,8 +92,8 @@ def get_fake_input(batch_size, max_tokens, max_num_nodes, max_num_edges):
   }


-def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges):
-  return {
+def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=False):
+  shapes = {
       'tokens': [max_tokens],
       'edge_sources': [max_num_edges],
       'edge_dests': [max_num_edges],
@@ -111,6 +111,13 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges):
       'num_nodes': [1],
       'num_edges': [1],
   }
+  if include_strings:
+    shapes.update({
+        'problem_id': [1],
+        'submission_id': [1],
+    })
+
+  return shapes


 def make_filter(
@@ -186,4 +193,4 @@ def load_dataset(dataset_path=codenet_paths.DEFAULT_DATASET_PATH, split='train',
     return load_tfrecords_dataset(tfrecord_paths, include_strings=include_strings)
   else:
     tfrecord_path = codenet_paths.make_tfrecord_path(dataset_path, split)
-    return load_tfrecord_dataset(tfrecord_path, include_strings=include_strings)
\ No newline at end of file
+    return load_tfrecord_dataset(tfrecord_path, include_strings=include_strings)
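
Because include_strings defaults to False, existing callers of get_padded_shapes keep the original shapes; only callers that opt in see the two identifier keys. A minimal sketch of how the new flag pairs with tf.data padded batching (not part of this commit; the batch size and max_* sizes below are placeholder assumptions):

from core.data import data_io

# Placeholder sizes, for illustration only.
shapes = data_io.get_padded_shapes(
    max_tokens=512, max_num_nodes=128, max_num_edges=128,
    include_strings=True)
# With include_strings=True, 'problem_id' and 'submission_id' appear with
# shape [1], so a dataset that carries the raw string identifiers batches
# without a padded-shape mismatch.
dataset = (
    data_io.load_dataset(split='train', include_strings=True)
    .padded_batch(8, padded_shapes=shapes)
)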

core/lib/trainer.py (9 additions, 5 deletions)

@@ -28,12 +28,12 @@
 from core.lib import metrics
 from core.lib import models
 from core.lib import optimizer_lib
-from core.lib.metrics import EvaluationMetric


 DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH

 Config = ml_collections.ConfigDict
+EvaluationMetric = metrics.EvaluationMetric


 class TrainState(train_state.TrainState):
@@ -47,7 +47,8 @@ class Trainer:
   info: Any

   def load_dataset(
-      self, dataset_path=DEFAULT_DATASET_PATH, split='train', epochs=None
+      self, dataset_path=DEFAULT_DATASET_PATH, split='train', epochs=None,
+      include_strings=False,
   ):
     config = self.config
     batch_size = config.batch_size
@@ -57,7 +58,9 @@ def load_dataset(
     allowlist = config.allowlist

     padded_shapes = data_io.get_padded_shapes(
-        config.max_tokens, config.max_num_nodes, config.max_num_edges)
+        config.max_tokens, config.max_num_nodes, config.max_num_edges, include_strings=include_strings)
+    print('padded_shapes')
+    print(padded_shapes)
     if allowlist == 'TIER1_ERROR_IDS':
       allowlist = error_kinds.TIER1_ERROR_IDS
     filter_fn = data_io.make_filter(
@@ -68,7 +71,7 @@ def load_dataset(
       # Prepare a dataset with a single repeating batch.
       split = split[:-len('-batch')]
       return (
-          data_io.load_dataset(dataset_path, split=split)
+          data_io.load_dataset(dataset_path, split=split, include_strings=include_strings)
          .filter(filter_fn)
          .take(batch_size)
          .repeat(epochs)
@@ -77,7 +80,7 @@

     # Return the requested dataset.
     return (
-        data_io.load_dataset(dataset_path, split=split)
+        data_io.load_dataset(dataset_path, split=split, include_strings=include_strings)
        .filter(filter_fn)
        .repeat(epochs)
        .shuffle(1000)
@@ -303,6 +306,7 @@ def run_train(self, dataset_path=DEFAULT_DATASET_PATH, split='train', steps=None
     train_predictions = []
     train_targets = []
     train_losses = []
+    print('Starting training')
     for step_index, batch in itertools.islice(enumerate(tfds.as_numpy(dataset)), steps):
       step = state.step
       if config.multidevice:
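
The parameter is now threaded through both data_io.load_dataset call sites (the single-repeating-batch branch and the normal branch), so callers can opt into string identifiers per dataset. A hypothetical call, assuming a config/info pair built the same way the training entry point builds them:

# Sketch only: config and info construction is elided.
t = trainer.Trainer(config=config, info=info)
# The default include_strings=False preserves the old training behavior;
# True keeps 'problem_id' and 'submission_id' in every example.
dataset = t.load_dataset(split='train', include_strings=True)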

scripts/visualize_predictions.py (59 additions, 0 deletions; new file)

@@ -0,0 +1,59 @@
+"""Visualize model predictions."""
+
+from absl import app
+from absl import flags
+
+import jax.numpy as jnp
+from ml_collections.config_flags import config_flags
+
+from core.data import codenet_paths
+from core.data import info as info_lib
+from core.lib import trainer
+
+DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH
+DEFAULT_CONFIG_PATH = codenet_paths.DEFAULT_CONFIG_PATH
+
+
+flags.DEFINE_string('dataset_path', DEFAULT_DATASET_PATH, 'Dataset path.')
+config_flags.DEFINE_config_file(
+    name='config', default=DEFAULT_CONFIG_PATH, help_string='Config file.'
+)
+FLAGS = flags.FLAGS
+
+
+def main(argv):
+  del argv  # Unused.
+
+  dataset_path = FLAGS.dataset_path
+  config = FLAGS.config
+  jnp.set_printoptions(threshold=config.printoptions_threshold)
+  info = info_lib.get_dataset_info(dataset_path)
+  t = trainer.Trainer(config=config, info=info)
+
+  dataset = t.load_dataset(
+      dataset_path=dataset_path, split='train', include_strings=True)
+
+  # for i, example in enumerate(dataset):
+  #   print('example', i)
+  #   print(example)
+  #   break
+
+  rng = jax.random.PRNGKey(0)
+  rng, init_rng = jax.random.split(rng)
+  model = t.make_model(deterministic=False)
+
+  state = t.create_train_state(init_rng, model)
+
+  train_step = t.make_train_step()
+  for batch in tfds.as_numpy(dataset):
+    if config.multidevice:
+      batch = common_utils.shard(batch)
+    problem_id = batch.pop('problem_id')
+    submission_id = batch.pop('submission_id')
+    state, aux = train_step(state, batch)
+    print(aux.keys())
+    print(aux)
+
+
+if __name__ == '__main__':
+  app.run(main)
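
As committed, the script references jax, tfds, and common_utils without importing them, so it would fail with a NameError at runtime; this matches the "Work in progress" note in the commit message. A sketch of the imports it would need, using only standard jax/flax/tfds modules:

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds

from flax.training import common_utils  # provides common_utils.shard

With those in place, the script would be invoked in the usual absl style, e.g. python -m scripts.visualize_predictions --dataset_path=<path> --config=<config_file>, using the two flags defined at the top of the file.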
