@@ -93,17 +93,21 @@ def _model(self):
93
93
94
94
def test_default_job_spec (self ):
95
95
self .assertStartsWith (self ._job_spec ["job_id" ], "cloud_fit_" )
96
- self . assertDictContainsSubset (
97
- {
98
- "masterConfig " : { "imageUri" : self . _image_uri ,},
99
- "args" : [
100
- "--remote_dir" ,
101
- self . _remote_dir ,
102
- "--distribution_strategy" ,
103
- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
104
- ],
105
- },
96
+ expected = {
97
+ "masterConfig" : { "imageUri" : self . _image_uri ,},
98
+ "args " : [
99
+ "--remote_dir" ,
100
+ self . _remote_dir ,
101
+ "--distribution_strategy" ,
102
+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
103
+ ] ,
104
+ }
105
+ self . assertEqual (
106
106
self ._job_spec ["trainingInput" ],
107
+ {
108
+ ** self ._job_spec ["trainingInput" ],
109
+ ** expected ,
110
+ }
107
111
)
108
112
109
113
@mock .patch .object (discovery , "build" , autospec = True )
@@ -125,17 +129,21 @@ def test_submit_job(self, mock_discovery_build):
125
129
126
130
_ , fit_kwargs = list (self ._mock_create .call_args )
127
131
body = fit_kwargs ["body" ]
128
- self . assertDictContainsSubset (
129
- {
130
- "masterConfig " : { "imageUri" : self . _image_uri ,},
131
- "args" : [
132
- "--remote_dir" ,
133
- self . _remote_dir ,
134
- "--distribution_strategy" ,
135
- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
136
- ],
137
- },
132
+ expected = {
133
+ "masterConfig" : { "imageUri" : self . _image_uri ,},
134
+ "args " : [
135
+ "--remote_dir" ,
136
+ self . _remote_dir ,
137
+ "--distribution_strategy" ,
138
+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
139
+ ] ,
140
+ }
141
+ self . assertEqual (
138
142
body ["trainingInput" ],
143
+ {
144
+ ** body ["trainingInput" ],
145
+ ** expected ,
146
+ }
139
147
)
140
148
self .assertStartsWith (body ["job_id" ], "cloud_fit_" )
141
149
self ._mock_get .execute .assert_called_with ()
@@ -212,8 +220,9 @@ def test_fit_kwargs(self, mock_submit_job):
212
220
os .path .join (remote_dir , "training_assets" )
213
221
)
214
222
elements = training_assets_graph .fit_kwargs_fn ()
215
- self .assertDictContainsSubset (tfds .as_numpy (
216
- elements ), {"batch_size" : 1 , "epochs" : 2 , "verbose" : 3 })
223
+ actual = {"batch_size" : 1 , "epochs" : 2 , "verbose" : 3 }
224
+ expected = tfds .as_numpy (elements )
225
+ self .assertEqual (actual , {** actual , ** expected })
217
226
218
227
@mock .patch .object (client , "_submit_job" , autospec = True )
219
228
def test_custom_job_spec (self , mock_submit_job ):
@@ -245,17 +254,21 @@ def test_custom_job_spec(self, mock_submit_job):
245
254
246
255
kargs , _ = mock_submit_job .call_args
247
256
body , _ = kargs
248
- self . assertDictContainsSubset (
249
- {
250
- "masterConfig " : { "imageUri" : self . _image_uri ,},
251
- "args" : [
252
- "--remote_dir" ,
253
- self . _remote_dir ,
254
- "--distribution_strategy" ,
255
- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
256
- ],
257
- },
257
+ expected = {
258
+ "masterConfig" : { "imageUri" : self . _image_uri ,},
259
+ "args " : [
260
+ "--remote_dir" ,
261
+ self . _remote_dir ,
262
+ "--distribution_strategy" ,
263
+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
264
+ ] ,
265
+ }
266
+ self . assertEqual (
258
267
body ["trainingInput" ],
268
+ {
269
+ ** body ["trainingInput" ],
270
+ ** expected ,
271
+ }
259
272
)
260
273
261
274
@mock .patch .object (client , "_submit_job" , autospec = True )
@@ -275,16 +288,20 @@ def test_distribution_strategy(
275
288
276
289
kargs , _ = mock_submit_job .call_args
277
290
body , _ = kargs
278
- self . assertDictContainsSubset (
279
- {
280
- "args" : [
281
- "--remote_dir" ,
282
- self . _remote_dir ,
283
- "--distribution_strategy" ,
284
- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
285
- ],
286
- },
291
+ expected = {
292
+ "args" : [
293
+ "--remote_dir" ,
294
+ self . _remote_dir ,
295
+ "--distribution_strategy" ,
296
+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
297
+ ] ,
298
+ }
299
+ self . assertEqual (
287
300
body ["trainingInput" ],
301
+ {
302
+ ** body ["trainingInput" ],
303
+ ** expected ,
304
+ }
288
305
)
289
306
290
307
client .cloud_fit (
@@ -297,16 +314,20 @@ def test_distribution_strategy(
297
314
298
315
kargs , _ = mock_submit_job .call_args
299
316
body , _ = kargs
300
- self . assertDictContainsSubset (
301
- {
302
- "args" : [
303
- "--remote_dir" ,
304
- self . _remote_dir ,
305
- "--distribution_strategy" ,
306
- MIRRORED_STRATEGY_NAME ,
307
- ],
308
- },
317
+ expected = {
318
+ "args" : [
319
+ "--remote_dir" ,
320
+ self . _remote_dir ,
321
+ "--distribution_strategy" ,
322
+ MIRRORED_STRATEGY_NAME ,
323
+ ] ,
324
+ }
325
+ self . assertEqual (
309
326
body ["trainingInput" ],
327
+ {
328
+ ** body ["trainingInput" ],
329
+ ** expected ,
330
+ }
310
331
)
311
332
312
333
with self .assertRaises (ValueError ):
@@ -351,7 +372,8 @@ def test_job_id(self, mock_serialize_assets, mock_submit_job):
351
372
352
373
kargs , _ = mock_submit_job .call_args
353
374
body , _ = kargs
354
- self .assertDictContainsSubset ({"job_id" : test_job_id ,}, body )
375
+ expected = {"job_id" : test_job_id ,}
376
+ self .assertEqual (body , {** body , ** expected })
355
377
356
378
357
379
if __name__ == "__main__" :
0 commit comments