|
7 | 7 |
|
8 | 8 | # pyre-strict
|
9 | 9 |
|
import concurrent.futures
import copy
import dataclasses
import logging
|
|
42 | 43 | ThroughputDef,
|
43 | 44 | )
|
44 | 45 | from torchrec.metrics.model_utils import parse_task_model_outputs
|
45 |
| -from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo |
| 46 | +from torchrec.metrics.rec_metric import RecMetricException, RecMetricList, RecTaskInfo |
46 | 47 | from torchrec.metrics.test_utils import gen_test_batch
|
47 | 48 | from torchrec.metrics.throughput import ThroughputMetric
|
48 | 49 | from torchrec.test_utils import get_free_port, seed_and_log, skip_if_asan_class
|
@@ -647,6 +648,22 @@ def test_save_and_load_state_dict(self) -> None:
|
647 | 648 | # Make sure num_batch wasn't created on the throughput module (and no exception was thrown above)
|
648 | 649 | self.assertFalse(hasattr(no_bss_metric_module.throughput_metric, "_num_batch"))
|
649 | 650 |
|
| 651 | + def test_async_compute_raises_exception(self) -> None: |
| 652 | + metric_module = generate_metric_module( |
| 653 | + TestMetricModule, |
| 654 | + metrics_config=DefaultMetricsConfig, |
| 655 | + batch_size=128, |
| 656 | + world_size=1, |
| 657 | + my_rank=0, |
| 658 | + state_metrics_mapping={}, |
| 659 | + device=torch.device("cpu"), |
| 660 | + ) |
| 661 | + with self.assertRaisesRegex( |
| 662 | + RecMetricException, |
| 663 | + "async_compute is not supported in RecMetricModule", |
| 664 | + ): |
| 665 | + metric_module.async_compute(concurrent.futures.Future()) |
| 666 | + |
650 | 667 |
|
651 | 668 | def metric_module_gather_state(
|
652 | 669 | rank: int,
|
@@ -702,6 +719,8 @@ def metric_module_gather_state(
|
702 | 719 | new_tensor = new_computed_value[metric]
|
703 | 720 | torch.testing.assert_close(tensor, new_tensor, check_device=False)
|
704 | 721 |
|
| 722 | + metric_module.shutdown() |
| 723 | + |
705 | 724 |
|
706 | 725 | @skip_if_asan_class
|
707 | 726 | class MetricModuleDistributedTest(MultiProcessTestBase):
|
|
0 commit comments