@@ -76,13 +76,13 @@ def run_train(self, train_data, dev_data):
 
         for epoch_id in range(self.start_epoch, self.num_epochs):
             print('Epoch {}'.format(epoch_id))
-            if self.rl_variation_tag.startswith('rs'):
-                # Reward shaping module sanity check:
-                # Make sure the reward shaping module output value is in the correct range
-                train_scores = self.test_fn(train_data)
-                dev_scores = self.test_fn(dev_data)
-                print('Train set average fact score: {}'.format(float(train_scores.mean())))
-                print('Dev set average fact score: {}'.format(float(dev_scores.mean())))
+            # if self.rl_variation_tag.startswith('rs'):
+            #     # Reward shaping module sanity check:
+            #     # Make sure the reward shaping module output value is in the correct range
+            #     train_scores = self.test_fn(train_data)
+            #     dev_scores = self.test_fn(dev_data)
+            #     print('Train set average fact score: {}'.format(float(train_scores.mean())))
+            #     print('Dev set average fact score: {}'.format(float(dev_scores.mean())))
 
             # Update model parameters
             self.train()
@@ -98,7 +98,7 @@ def run_train(self, train_data, dev_data):
         if self.run_analysis:
             rewards = None
             fns = None
-        for example_id in tqdm(range(0, len(train_data), self.batch_size)):
+        for example_id in tqdm(range(0, 127, self.batch_size)):
 
             self.optim.zero_grad()
 
@@ -154,7 +154,7 @@ def run_train(self, train_data, dev_data):
                 eta = self.action_dropout_anneal_interval
                 if len(dev_metrics_history) > eta and metrics < min(dev_metrics_history[-eta:]):
                     old_action_dropout_rate = self.action_dropout_rate
-                    self.action_dropout_rate *= self.action_dropout_anneal_factor
+                    self.action_dropout_rate *= self.action_dropout_anneal_factor
                     print('Decreasing action dropout rate: {} -> {}'.format(
                         old_action_dropout_rate, self.action_dropout_rate))
             # Save checkpoint
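For context, a minimal standalone sketch of the action-dropout annealing rule touched by the last hunk: when the latest dev metric is worse than every metric in the most recent action_dropout_anneal_interval evaluations, the dropout rate is multiplied by action_dropout_anneal_factor. The helper name and the default values below are illustrative, not taken from the repository.

def maybe_anneal_action_dropout(dev_metrics_history, current_metric, rate,
                                anneal_interval=5, anneal_factor=0.9):
    # Decay the rate only when the current dev metric is worse than every one
    # of the last `anneal_interval` recorded dev metrics.
    if (len(dev_metrics_history) > anneal_interval
            and current_metric < min(dev_metrics_history[-anneal_interval:])):
        old_rate = rate
        rate *= anneal_factor
        print('Decreasing action dropout rate: {} -> {}'.format(old_rate, rate))
    return rate

# Example: the latest dev metric (0.29) falls below the recent window, so the rate decays.
history = [0.30, 0.31, 0.32, 0.33, 0.34, 0.35]
new_rate = maybe_anneal_action_dropout(history, current_metric=0.29, rate=0.5)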