Skip to content

Commit 22406c2

Browse files
jeffkbkimmeta-codesync[bot]
authored andcommitted
fix metric module code coverage (#3439)
Summary: Pull Request resolved: #3439 Add missing test cases. RecMetricModule does not support async_compute(). Reviewed By: nipung90 Differential Revision: D84070148 fbshipit-source-id: f791da40f41950f66fc05f570826dfe4cf6918b0
1 parent fdb46f1 commit 22406c2

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

torchrec/metrics/tests/test_metric_module.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-strict
99

10+
import concurrent
1011
import copy
1112
import dataclasses
1213
import logging
@@ -42,7 +43,7 @@
4243
ThroughputDef,
4344
)
4445
from torchrec.metrics.model_utils import parse_task_model_outputs
45-
from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
46+
from torchrec.metrics.rec_metric import RecMetricException, RecMetricList, RecTaskInfo
4647
from torchrec.metrics.test_utils import gen_test_batch
4748
from torchrec.metrics.throughput import ThroughputMetric
4849
from torchrec.test_utils import get_free_port, seed_and_log, skip_if_asan_class
@@ -647,6 +648,22 @@ def test_save_and_load_state_dict(self) -> None:
647648
# Make sure num_batch wasn't created on the throughput module (and no exception was thrown above)
648649
self.assertFalse(hasattr(no_bss_metric_module.throughput_metric, "_num_batch"))
649650

651+
def test_async_compute_raises_exception(self) -> None:
652+
metric_module = generate_metric_module(
653+
TestMetricModule,
654+
metrics_config=DefaultMetricsConfig,
655+
batch_size=128,
656+
world_size=1,
657+
my_rank=0,
658+
state_metrics_mapping={},
659+
device=torch.device("cpu"),
660+
)
661+
with self.assertRaisesRegex(
662+
RecMetricException,
663+
"async_compute is not supported in RecMetricModule",
664+
):
665+
metric_module.async_compute(concurrent.futures.Future())
666+
650667

651668
def metric_module_gather_state(
652669
rank: int,
@@ -702,6 +719,8 @@ def metric_module_gather_state(
702719
new_tensor = new_computed_value[metric]
703720
torch.testing.assert_close(tensor, new_tensor, check_device=False)
704721

722+
metric_module.shutdown()
723+
705724

706725
@skip_if_asan_class
707726
class MetricModuleDistributedTest(MultiProcessTestBase):

0 commit comments

Comments
 (0)