|
1 | 1 | import random |
| 2 | +from pathlib import Path |
2 | 3 | from typing import Dict, Tuple |
3 | 4 | from unittest import TestCase |
4 | 5 |
|
|
7 | 8 |
|
8 | 9 | from rsd.data.ists import ISTSDataset |
9 | 10 | from rsd.data.pawsx import PAWSXDataset |
10 | | -from rsd.experiments.benchmark import DifferenceRecognitionBenchmark |
| 11 | +from rsd.data.synthetic import SyntheticSTSDataset |
| 12 | +from rsd.experiments.benchmark import DifferenceRecognitionBenchmark, MultiLengthDifferenceRecognitionBenchmark |
11 | 13 | from rsd.recognizers.base import DifferenceRecognizer |
12 | 14 | from rsd.recognizers.utils import DifferenceSample |
13 | 15 |
|
14 | 16 |
|
15 | 17 | class DifferenceRecognitionBenchmarkTestCase(TestCase): |
16 | 18 |
|
17 | 19 | def setUp(self) -> None: |
18 | | - self.positive_dataset = ISTSDataset() |
| 20 | + # self.positive_dataset = ISTSDataset() |
| 21 | + self.positive_dataset = SyntheticSTSDataset( |
| 22 | + path=Path(__file__).parent.parent.parent / "ists_finetuning" / "data" / "synthetic_rsd" / "ft:gpt-4o-mini-2024-07-18:cl-uzh:rsd-test-en-v3:AxfIBckt_train_v2.jsonl", |
| 23 | + ) |
19 | 24 | self.negative_dataset = PAWSXDataset() |
20 | 25 | self.benchmark = DifferenceRecognitionBenchmark( |
21 | 26 | positive_dataset=self.positive_dataset, |
@@ -109,7 +114,7 @@ def predict(self, a: str, b: str, *args, **kwargs): |
109 | 114 |
|
110 | 115 | recognizer = OracleRecognizer(self.benchmark) |
111 | 116 | result = self.benchmark.evaluate(recognizer) |
112 | | - self.assertEqual(1, result.spearman) |
| 117 | + self.assertAlmostEqual(1, result.spearman) |
113 | 118 |
|
114 | 119 | def test_to_dataset(self): |
115 | 120 | # Test conversion to HuggingFace dataset |
@@ -153,3 +158,39 @@ def test_to_dataset_both_directions(self): |
153 | 158 | self.assertEqual(second_example["labels_a"], list(first_doc.labels_b)) |
154 | 159 | self.assertEqual(second_example["labels_b"], list(first_doc.labels_a)) |
155 | 160 |
|
| 161 | + |
| 162 | +class TestMultiLengthDifferenceRecognitionBenchmark(TestCase): |
| 163 | + def setUp(self): |
| 164 | + self.positive_dataset = SyntheticSTSDataset( |
| 165 | + path=Path(__file__).parent.parent.parent / "ists_finetuning" / "data" / "synthetic_rsd" / "ft:gpt-4o-mini-2024-07-18:cl-uzh:rsd-test-en-v3:AxfIBckt_train_v2.jsonl", |
| 166 | + ) |
| 167 | + self.negative_dataset = PAWSXDataset() |
| 168 | + |
| 169 | + def test_basic_functionality(self): |
| 170 | + benchmark = MultiLengthDifferenceRecognitionBenchmark( |
| 171 | + positive_dataset=self.positive_dataset, |
| 172 | + negative_dataset=self.negative_dataset, |
| 173 | + positive_ratio=0.5, |
| 174 | + max_sentences_per_document=3, |
| 175 | + max_inversions=2, |
| 176 | + seed=42 |
| 177 | + ) |
| 178 | + |
| 179 | + # Check that we have the expected number of benchmarks |
| 180 | + self.assertEqual(len(benchmark.benchmarks), 3) |
| 181 | + |
| 182 | + # Check that num_sentences_range is correct |
| 183 | + self.assertEqual(benchmark.num_sentences_range, [1, 2, 3]) |
| 184 | + |
| 185 | + # Check that num_inversions_range is correct (0 for the first entry, then 1, 2) |
| 186 | + self.assertEqual(benchmark.num_inversions_range, [0, 1, 2]) |
| 187 | + |
| 188 | + # Test dataset conversion |
| 189 | + dataset = benchmark.to_dataset() |
| 190 | + self.assertIsInstance(dataset, datasets.Dataset) |
| 191 | + self.assertTrue({"text_a", "text_b", "labels_a", "labels_b"}.issubset(dataset.features)) |
| 192 | + |
| 193 | + # Test both directions |
| 194 | + dataset_both = benchmark.to_dataset(both_directions=True) |
| 195 | + self.assertEqual(len(dataset_both), len(dataset) * 2) |
| 196 | + |
0 commit comments