1616"""Tests for tensorflow_datasets.scripts.cli.build."""
1717
1818import contextlib
19+ import dataclasses
20+ import multiprocessing
1921import os
2022import pathlib
2123from typing import Dict , Iterator , List , Optional
2224from unittest import mock
2325
24- from absl .testing import parameterized
2526from etils import epath
2627import pytest
2728import tensorflow_datasets as tfds
@@ -89,12 +90,12 @@ def _build(cmd_flags: str, mock_download_and_prepare: bool = True) -> List[str]:
8990 # to patch the function to record the generated_ds manually.
9091 # See:
9192 # https://stackoverflow.com/questions/64792295/how-to-get-self-instance-in-mock-mock-call-args
92- generated_ds_names = []
93+ queue = multiprocessing . Queue ()
9394
9495 def _download_and_prepare (self , * args , ** kwargs ):
9596 # Remove version from generated name (as only last version can be generated)
9697 full_name = '/' .join (self .info .full_name .split ('/' )[:- 1 ])
97- generated_ds_names . append (full_name )
98+ queue . put (full_name )
9899 if mock_download_and_prepare :
99100 return
100101 else :
@@ -105,6 +106,12 @@ def _download_and_prepare(self, *args, **kwargs):
105106 _download_and_prepare ,
106107 ):
107108 main .main (args )
109+ queue .put (None )
110+
111+ generated_ds_names = []
112+ while full_name := queue .get ():
113+ generated_ds_names .append (full_name )
114+
108115 return generated_ds_names
109116
110117
@@ -139,10 +146,10 @@ def test_build_multiple():
139146 ]
140147
141148
142- @parameterized . parameters ( range (5 ))
149+ @pytest . mark . parametrize ( 'num_processes' , range (1 , 4 ))
143150def test_build_parallel (num_processes ):
144151 # Order is not guaranteed
145- assert set (_build (f'trivia_qa --num-proccesses ={ num_processes } ' )) == set ([
152+ assert set (_build (f'trivia_qa --num-processes ={ num_processes } ' )) == set ([
146153 'trivia_qa/rc' ,
147154 'trivia_qa/rc.nocontext' ,
148155 'trivia_qa/unfiltered' ,
@@ -288,22 +295,29 @@ def test_download_only():
288295 mock_download .assert_called_with ({'file0' : 'http://data.org/file1.zip' })
289296
290297
291- @parameterized .parameters (
292- ('--manual_dir=/a/b' , {'manual_dir' : '/a/b' }),
293- ('--manual_dir=/a/b --add_name_to_manual_dir' , {'manual_dir' : '/a/b/x' }),
294- ('--extract_dir=/a/b' , {'extract_dir' : '/a/b' }),
295- ('--max_examples_per_split=42' , {'max_examples_per_split' : 42 }),
296- ('--register_checksums' , {'register_checksums' : True }),
297- ('--force_checksums_validation' , {'force_checksums_validation' : True }),
298- ('--max_shard_size_mb=128' , {'max_shard_size' : 128 << 20 }),
299- (
300- '--download_config="{\' max_shard_size\' : 1234}"' ,
301- {'max_shard_size' : 1234 },
302- ),
298+ @pytest .mark .parametrize (
299+ 'args,download_config_kwargs' ,
300+ [
301+ ('--manual_dir=/a/b' , {'manual_dir' : epath .Path ('/a/b' )}),
302+ (
303+ '--manual_dir=/a/b --add_name_to_manual_dir' ,
304+ {'manual_dir' : epath .Path ('/a/b/x' )},
305+ ),
306+ ('--extract_dir=/a/b' , {'extract_dir' : epath .Path ('/a/b' )}),
307+ ('--max_examples_per_split=42' , {'max_examples_per_split' : 42 }),
308+ ('--register_checksums' , {'register_checksums' : True }),
309+ ('--force_checksums_validation' , {'force_checksums_validation' : True }),
310+ ('--max_shard_size_mb=128' , {'max_shard_size' : 128 << 20 }),
311+ (
312+ '--download_config={"max_shard_size":1234}' ,
313+ {'max_shard_size' : 1234 },
314+ ),
315+ ],
303316)
304317def test_make_download_config (args : str , download_config_kwargs ):
305- args = main ._parse_flags (f'tfds build x { download_config_kwargs } ' .split ())
318+ args = main ._parse_flags (f'tfds build x { args } ' .split ())
306319 actual = build_lib ._make_download_config (args , dataset_name = 'x' )
307320 # Ignore the beam runner
308- actual .replace (beam_runner = None )
309- assert actual == tfds .download .DownloadConfig (** download_config_kwargs )
321+ actual = actual .replace (beam_runner = None )
322+ expected = tfds .download .DownloadConfig (** download_config_kwargs )
323+ assert dataclasses .asdict (actual ) == dataclasses .asdict (expected )
0 commit comments