Skip to content

Commit 9334cad

Browse files
author
Googler
committed
remove deprecated assertDictContainsSubset
PiperOrigin-RevId: 715816954
1 parent 03b3f6a commit 9334cad

File tree

2 files changed

+75
-53
lines changed

2 files changed

+75
-53
lines changed

src/python/tensorflow_cloud/tuner/tests/unit/cloud_fit_client_test.py

Lines changed: 73 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,21 @@ def _model(self):
9393

9494
def test_default_job_spec(self):
9595
self.assertStartsWith(self._job_spec["job_id"], "cloud_fit_")
96-
self.assertDictContainsSubset(
97-
{
98-
"masterConfig": {"imageUri": self._image_uri,},
99-
"args": [
100-
"--remote_dir",
101-
self._remote_dir,
102-
"--distribution_strategy",
103-
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
104-
],
105-
},
96+
expected = {
97+
"masterConfig": {"imageUri": self._image_uri,},
98+
"args": [
99+
"--remote_dir",
100+
self._remote_dir,
101+
"--distribution_strategy",
102+
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
103+
],
104+
}
105+
self.assertEqual(
106106
self._job_spec["trainingInput"],
107+
{
108+
**self._job_spec["trainingInput"],
109+
**expected,
110+
}
107111
)
108112

109113
@mock.patch.object(discovery, "build", autospec=True)
@@ -125,17 +129,21 @@ def test_submit_job(self, mock_discovery_build):
125129

126130
_, fit_kwargs = list(self._mock_create.call_args)
127131
body = fit_kwargs["body"]
128-
self.assertDictContainsSubset(
129-
{
130-
"masterConfig": {"imageUri": self._image_uri,},
131-
"args": [
132-
"--remote_dir",
133-
self._remote_dir,
134-
"--distribution_strategy",
135-
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
136-
],
137-
},
132+
expected = {
133+
"masterConfig": {"imageUri": self._image_uri,},
134+
"args": [
135+
"--remote_dir",
136+
self._remote_dir,
137+
"--distribution_strategy",
138+
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
139+
],
140+
}
141+
self.assertEqual(
138142
body["trainingInput"],
143+
{
144+
**body["trainingInput"],
145+
**expected,
146+
}
139147
)
140148
self.assertStartsWith(body["job_id"], "cloud_fit_")
141149
self._mock_get.execute.assert_called_with()
@@ -212,8 +220,9 @@ def test_fit_kwargs(self, mock_submit_job):
212220
os.path.join(remote_dir, "training_assets")
213221
)
214222
elements = training_assets_graph.fit_kwargs_fn()
215-
self.assertDictContainsSubset(tfds.as_numpy(
216-
elements), {"batch_size": 1, "epochs": 2, "verbose": 3})
223+
actual = {"batch_size": 1, "epochs": 2, "verbose": 3}
224+
expected = tfds.as_numpy(elements)
225+
self.assertEqual(actual, {**actual, **expected})
217226

218227
@mock.patch.object(client, "_submit_job", autospec=True)
219228
def test_custom_job_spec(self, mock_submit_job):
@@ -245,17 +254,21 @@ def test_custom_job_spec(self, mock_submit_job):
245254

246255
kargs, _ = mock_submit_job.call_args
247256
body, _ = kargs
248-
self.assertDictContainsSubset(
249-
{
250-
"masterConfig": {"imageUri": self._image_uri,},
251-
"args": [
252-
"--remote_dir",
253-
self._remote_dir,
254-
"--distribution_strategy",
255-
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
256-
],
257-
},
257+
expected = {
258+
"masterConfig": {"imageUri": self._image_uri,},
259+
"args": [
260+
"--remote_dir",
261+
self._remote_dir,
262+
"--distribution_strategy",
263+
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
264+
],
265+
}
266+
self.assertEqual(
258267
body["trainingInput"],
268+
{
269+
**body["trainingInput"],
270+
**expected,
271+
}
259272
)
260273

261274
@mock.patch.object(client, "_submit_job", autospec=True)
@@ -275,16 +288,20 @@ def test_distribution_strategy(
275288

276289
kargs, _ = mock_submit_job.call_args
277290
body, _ = kargs
278-
self.assertDictContainsSubset(
279-
{
280-
"args": [
281-
"--remote_dir",
282-
self._remote_dir,
283-
"--distribution_strategy",
284-
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
285-
],
286-
},
291+
expected = {
292+
"args": [
293+
"--remote_dir",
294+
self._remote_dir,
295+
"--distribution_strategy",
296+
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
297+
],
298+
}
299+
self.assertEqual(
287300
body["trainingInput"],
301+
{
302+
**body["trainingInput"],
303+
**expected,
304+
}
288305
)
289306

290307
client.cloud_fit(
@@ -297,16 +314,20 @@ def test_distribution_strategy(
297314

298315
kargs, _ = mock_submit_job.call_args
299316
body, _ = kargs
300-
self.assertDictContainsSubset(
301-
{
302-
"args": [
303-
"--remote_dir",
304-
self._remote_dir,
305-
"--distribution_strategy",
306-
MIRRORED_STRATEGY_NAME,
307-
],
308-
},
317+
expected = {
318+
"args": [
319+
"--remote_dir",
320+
self._remote_dir,
321+
"--distribution_strategy",
322+
MIRRORED_STRATEGY_NAME,
323+
],
324+
}
325+
self.assertEqual(
309326
body["trainingInput"],
327+
{
328+
**body["trainingInput"],
329+
**expected,
330+
}
310331
)
311332

312333
with self.assertRaises(ValueError):
@@ -351,7 +372,8 @@ def test_job_id(self, mock_serialize_assets, mock_submit_job):
351372

352373
kargs, _ = mock_submit_job.call_args
353374
body, _ = kargs
354-
self.assertDictContainsSubset({"job_id": test_job_id,}, body)
375+
expected = {"job_id": test_job_id,}
376+
self.assertEqual(body, {**body, **expected})
355377

356378

357379
if __name__ == "__main__":

src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@ def test_get_or_set_consent_status_notify_user(self):
256256

257257
with open(self._local_config_path) as config_json:
258258
config_data = json.load(config_json)
259-
self.assertDictContainsSubset(
260-
config_data, {"notification_version": version.__version__})
259+
actual = {"notification_version": version.__version__}
260+
self.assertEqual(actual, {**actual, **config_data})
261261

262262
@mock.patch.object(google_api_client,
263263
"get_or_set_consent_status", autospec=True)

0 commit comments

Comments
 (0)