Skip to content

Commit c7816ee

Browse files
authored
Merge pull request octo-models#108 from rail-berkeley/kevin-tmp-shuffle
Shuffle before decode, interleave at transition level
2 parents 8d26158 + 585741a commit c7816ee

File tree

11 files changed

+400
-279
lines changed

11 files changed

+400
-279
lines changed

config.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -66,15 +66,15 @@ def get_config(
6666
clip_gradient=1.0,
6767
frozen_keys=tuple(),
6868
),
69-
batch_size=1024,
7069
eval_batch_size=128,
71-
shuffle_buffer_size=100000,
70+
prefetch_num_batches=0,
7271
val_shuffle_buffer_size=1000,
7372
num_val_batches=16,
7473
start_step=placeholder(int),
7574
log_interval=100,
7675
eval_interval=5000,
77-
save_interval=5000,
76+
viz_interval=20000,
77+
save_interval=10000,
7878
trajs_for_metrics=100,
7979
trajs_for_viz=8,
8080
resume_path=placeholder(str),
@@ -114,7 +114,7 @@ def get_dataset_config(modality="multimodal", window_size=1):
114114
raise ValueError(f"Unknown modality {modality}")
115115

116116
return {
117-
# oxe_kwargs will generate data_kwargs_list and sampling weights
117+
# oxe_kwargs will generate dataset_kwargs_list and sampling weights
118118
"oxe_kwargs": dict(
119119
data_mix=placeholder(str),
120120
# for v4 TPUs: "gs://rail-orca-central2/resize_336_336"
@@ -123,18 +123,19 @@ def get_dataset_config(modality="multimodal", window_size=1):
123123
n_wrist_cameras=0,
124124
load_depth=False,
125125
),
126-
# common_kwargs override specific kwargs from data_kwargs_list
127-
"common_kwargs": dict(
128-
ram_budget=1, # limit RAM per dataset
129-
num_parallel_reads=8, # for reading from GCS
130-
num_parallel_calls=16, # for the less CPU-intensive ops in initial dataset construction
126+
# common_dataset_kwargs override specific kwargs from dataset_kwargs_list
127+
"common_dataset_kwargs": dict(
131128
action_proprio_normalization_type=normalization_type,
132129
),
133-
"transform_kwargs": dict(
134-
resize_size=(256, 256),
135-
num_parallel_calls=32, # for the most CPU-intensive ops (decoding, resizing, augmenting)
130+
"traj_transform_kwargs": dict(
136131
window_size=window_size,
137132
additional_action_window_size=0,
133+
goal_relabeling_strategy="uniform",
134+
subsample_length=100,
135+
**task_augmentation,
136+
),
137+
"frame_transform_kwargs": dict(
138+
resize_size=(256, 256),
138139
image_augment_kwargs=dict(
139140
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
140141
random_brightness=[0.2],
@@ -149,9 +150,13 @@ def get_dataset_config(modality="multimodal", window_size=1):
149150
"random_hue",
150151
],
151152
),
152-
goal_relabeling_strategy="uniform",
153-
**task_augmentation,
154153
),
154+
"traj_transform_threads": 48, # shared between all datasets
155+
"traj_read_threads": 48, # shared between all datasets
156+
"frame_transform_threads": 200, # not shared between datasets
157+
"shuffle_buffer_size": 100000, # shared between all datasets
158+
"batch_size": 1024,
159+
"balance_weights": True,
155160
}
156161

157162

experiments/dibya/finetune_config.py

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@ def get_config(
3232
"action_encoding": ActionEncoding.EEF_POS,
3333
"action_proprio_normalization_type": "normal",
3434
# If the default data loading speed is too slow, try these:
35-
# and "num_parallel_calls" in `transform_kwargs` below
3635
# "num_parallel_reads": 8, # for reading from disk / GCS
3736
# "num_parallel_calls": 16, # for initial dataset construction
3837
}
@@ -70,7 +69,7 @@ def get_config(
7069
wandb=dict(
7170
project="orca_finetune", group=placeholder(str), entity=placeholder(str)
7271
),
73-
finetuning_dataset=FINETUNING_KWARGS,
72+
dataset_kwargs=FINETUNING_KWARGS,
7473
modality=task,
7574
finetuning_mode=mode,
7675
window_size=window_size,
@@ -107,9 +106,18 @@ def get_config(
107106
else:
108107
raise ValueError("Invalid modality")
109108

110-
transform_kwargs = dict(
109+
traj_transform_kwargs = dict(
111110
window_size=window_size,
112111
additional_action_window_size=0,
112+
goal_relabeling_strategy=goal_relabeling_strategy,
113+
task_augmentation_strategy="delete_task_conditioning",
114+
task_augmentation_kwargs=dict(
115+
delete_key_groups_probs=delete_key_groups_probs,
116+
),
117+
# If the default data loading speed is too slow, try these:
118+
# num_parallel_calls=16, # for less CPU-intensive ops
119+
)
120+
frame_transform_kwargs = dict(
113121
resize_size=(256, 256),
114122
image_augment_kwargs=dict(
115123
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
@@ -125,13 +133,12 @@ def get_config(
125133
"random_hue",
126134
],
127135
),
128-
goal_relabeling_strategy=goal_relabeling_strategy,
129-
task_augmentation_strategy="delete_task_conditioning",
130-
task_augmentation_kwargs=dict(
131-
delete_key_groups_probs=delete_key_groups_probs,
132-
),
133-
# If the default data loading speed is too slow, try these:
134-
# num_parallel_calls=16, # for the most CPU-intensive ops (decoding, resizing, augmenting)
135136
)
136-
config["data_transforms"] = transform_kwargs
137+
# If the default data loading speed is too slow, try these:
138+
config[
139+
"frame_transform_threads"
140+
] = 16 # for the most CPU-intensive ops (decoding, resizing, augmenting)
141+
142+
config["traj_transform_kwargs"] = traj_transform_kwargs
143+
config["frame_transform_kwargs"] = frame_transform_kwargs
137144
return ConfigDict(config)

finetune.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,8 @@ def main(_):
6464
ORCA Finetuning Script
6565
======================
6666
Pretrained model: {FLAGS.config.pretrained_path}
67-
Finetuning Dataset: {FLAGS.config.finetuning_dataset.name}
68-
Data dir: {FLAGS.config.finetuning_dataset.data_dir}
67+
Finetuning Dataset: {FLAGS.config.dataset_kwargs.name}
68+
Data dir: {FLAGS.config.dataset_kwargs.data_dir}
6969
Task Modality: {FLAGS.config.modality}
7070
Finetuning Mode: {FLAGS.config.finetuning_mode}
7171
@@ -159,13 +159,22 @@ def process_text(batch):
159159
del batch["dataset_name"]
160160
return batch
161161

162-
data_kwargs = FLAGS.config.finetuning_dataset
163-
transform_kwargs = FLAGS.config.data_transforms
164-
165-
dataset = make_single_dataset(data_kwargs, transform_kwargs, train=True)
166-
val_dataset = make_single_dataset(data_kwargs, transform_kwargs, train=False)
162+
dataset = make_single_dataset(
163+
FLAGS.config.dataset_kwargs,
164+
FLAGS.config.traj_transform_kwargs,
165+
FLAGS.config.frame_transform_kwargs,
166+
train=True,
167+
frame_transform_threads=FLAGS.config.frame_transform_threads,
168+
)
169+
val_dataset = make_single_dataset(
170+
FLAGS.config.dataset_kwargs,
171+
FLAGS.config.traj_transform_kwargs,
172+
FLAGS.config.frame_transform_kwargs,
173+
train=False,
174+
frame_transform_threads=FLAGS.config.frame_transform_threads,
175+
)
167176
visualizer = Visualizer(
168-
val_dataset, text_processor=text_processor, cache_trajs=False
177+
val_dataset, text_processor=text_processor, freeze_trajs=False
169178
)
170179

171180
def create_iterator(dataset):

0 commit comments

Comments
 (0)