Skip to content

Commit 673e5ca

Browse files
committed
Make training data of bert configurable
1 parent 0be7f97 commit 673e5ca

File tree

4 files changed

+123
-3
lines changed

4 files changed

+123
-3
lines changed

data/synthetic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from pathlib import Path
2+
from typing import List
3+
4+
import jsonlines
5+
6+
from rsd.data import DifferenceDataset
7+
from rsd.recognizers.utils import DifferenceSample
8+
9+
class SyntheticSTSDataset(DifferenceDataset):
10+
11+
def __init__(self,
12+
path: Path,
13+
):
14+
super().__init__()
15+
self.path = path
16+
with jsonlines.open(path) as reader:
17+
self.dataset = list(reader)
18+
19+
def get_samples(self) -> List[DifferenceSample]:
20+
samples = []
21+
for sample in self.dataset:
22+
samples.append(DifferenceSample(
23+
tokens_a=sample["text_a"].split(),
24+
tokens_b=sample["text_b"].split(),
25+
labels_a=sample["labels_a"],
26+
labels_b=sample["labels_b"],
27+
))
28+
return samples
29+
30+
def __str__(self):
31+
return f"SyntheticDataset({self.path.name})"

experiments/benchmark.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,35 @@ def to_dataset(self, both_directions: bool = False) -> datasets.Dataset:
189189
data["labels_b"].append(list(doc.labels_a))
190190

191191
return datasets.Dataset.from_dict(data)
192+
193+
194+
class MultiLengthDifferenceRecognitionBenchmark:
195+
196+
def __init__(self,
197+
positive_dataset: DifferenceDataset,
198+
negative_dataset: DifferenceDataset = None,
199+
positive_ratio: float = 1.0,
200+
max_sentences_per_document: int = 1,
201+
max_inversions: int = 0,
202+
seed: int = 42,
203+
):
204+
assert max_sentences_per_document >= 1
205+
assert max_inversions <= max_sentences_per_document
206+
self.num_sentences_range = list(range(1, max_sentences_per_document + 1))
207+
self.num_inversions_range = [0] * (max_sentences_per_document - max_inversions) + list(range(1, max_inversions + 1))
208+
assert len(self.num_inversions_range) == len(self.num_sentences_range)
209+
self.benchmarks = []
210+
for num_sentences, num_inversions in zip(self.num_sentences_range, self.num_inversions_range):
211+
benchmark = DifferenceRecognitionBenchmark(
212+
positive_dataset=positive_dataset,
213+
negative_dataset=negative_dataset,
214+
positive_ratio=positive_ratio,
215+
num_sentences_per_document=num_sentences,
216+
num_inversions=num_inversions,
217+
seed=seed,
218+
)
219+
self.benchmarks.append(benchmark)
220+
221+
def to_dataset(self, both_directions: bool = False) -> datasets.Dataset:
222+
ds = [benchmark.to_dataset(both_directions=both_directions) for benchmark in self.benchmarks]
223+
return datasets.concatenate_datasets(ds)

tests/test_benchmark.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
from pathlib import Path
23
from typing import Dict, Tuple
34
from unittest import TestCase
45

@@ -7,15 +8,19 @@
78

89
from rsd.data.ists import ISTSDataset
910
from rsd.data.pawsx import PAWSXDataset
10-
from rsd.experiments.benchmark import DifferenceRecognitionBenchmark
11+
from rsd.data.synthetic import SyntheticSTSDataset
12+
from rsd.experiments.benchmark import DifferenceRecognitionBenchmark, MultiLengthDifferenceRecognitionBenchmark
1113
from rsd.recognizers.base import DifferenceRecognizer
1214
from rsd.recognizers.utils import DifferenceSample
1315

1416

1517
class DifferenceRecognitionBenchmarkTestCase(TestCase):
1618

1719
def setUp(self) -> None:
18-
self.positive_dataset = ISTSDataset()
20+
# self.positive_dataset = ISTSDataset()
21+
self.positive_dataset = SyntheticSTSDataset(
22+
path=Path(__file__).parent.parent.parent / "ists_finetuning" / "data" / "synthetic_rsd" / "ft:gpt-4o-mini-2024-07-18:cl-uzh:rsd-test-en-v3:AxfIBckt_train_v2.jsonl",
23+
)
1924
self.negative_dataset = PAWSXDataset()
2025
self.benchmark = DifferenceRecognitionBenchmark(
2126
positive_dataset=self.positive_dataset,
@@ -109,7 +114,7 @@ def predict(self, a: str, b: str, *args, **kwargs):
109114

110115
recognizer = OracleRecognizer(self.benchmark)
111116
result = self.benchmark.evaluate(recognizer)
112-
self.assertEqual(1, result.spearman)
117+
self.assertAlmostEqual(1, result.spearman)
113118

114119
def test_to_dataset(self):
115120
# Test conversion to HuggingFace dataset
@@ -153,3 +158,39 @@ def test_to_dataset_both_directions(self):
153158
self.assertEqual(second_example["labels_a"], list(first_doc.labels_b))
154159
self.assertEqual(second_example["labels_b"], list(first_doc.labels_a))
155160

161+
162+
class TestMultiLengthDifferenceRecognitionBenchmark(TestCase):
163+
def setUp(self):
164+
self.positive_dataset = SyntheticSTSDataset(
165+
path=Path(__file__).parent.parent.parent / "ists_finetuning" / "data" / "synthetic_rsd" / "ft:gpt-4o-mini-2024-07-18:cl-uzh:rsd-test-en-v3:AxfIBckt_train_v2.jsonl",
166+
)
167+
self.negative_dataset = PAWSXDataset()
168+
169+
def test_basic_functionality(self):
170+
benchmark = MultiLengthDifferenceRecognitionBenchmark(
171+
positive_dataset=self.positive_dataset,
172+
negative_dataset=self.negative_dataset,
173+
positive_ratio=0.5,
174+
max_sentences_per_document=3,
175+
max_inversions=2,
176+
seed=42
177+
)
178+
179+
# Check that we have the expected number of benchmarks
180+
self.assertEqual(len(benchmark.benchmarks), 3)
181+
182+
# Check that num_sentences_range is correct
183+
self.assertEqual(benchmark.num_sentences_range, [1, 2, 3])
184+
185+
# Check that num_inversions_range is correct (1 for first entry, then 1, 2)
186+
self.assertEqual(benchmark.num_inversions_range, [0, 1, 2])
187+
188+
# Test dataset conversion
189+
dataset = benchmark.to_dataset()
190+
self.assertIsInstance(dataset, datasets.Dataset)
191+
self.assertTrue({"text_a", "text_b", "labels_a", "labels_b"}.issubset(dataset.features))
192+
193+
# Test both directions
194+
dataset_both = benchmark.to_dataset(both_directions=True)
195+
self.assertEqual(len(dataset_both), len(dataset) * 2)
196+

tests/test_data.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from pathlib import Path
12
from unittest import TestCase
23

34
from rsd.data.ists import ISTSDataset
45
from rsd.data.pawsx import PAWSXDataset, CrosslingualPAWSXDataset
6+
from rsd.data.synthetic import SyntheticSTSDataset
57
from rsd.recognizers.utils import DifferenceSample
68

79

@@ -59,3 +61,17 @@ def test_get_samples(self):
5961
self.assertSetEqual(set(sample.labels_b), {-1})
6062
print(samples[0])
6163
print(samples[1])
64+
65+
66+
class SyntheticSTSDatasetTestCase(TestCase):
67+
68+
def setUp(self) -> None:
69+
self.dataset = SyntheticSTSDataset(
70+
path=Path(__file__).parent.parent.parent / "ists_finetuning" / "data" / "synthetic_rsd" / "ft:gpt-4o-mini-2024-07-18:cl-uzh:rsd-test-en-v3:AxfIBckt_train_v2.jsonl",
71+
)
72+
73+
def test_get_samples(self):
74+
sample: DifferenceSample = self.dataset.get_samples()[0]
75+
self.assertEqual(len(sample.tokens_a), len(sample.labels_a))
76+
self.assertEqual(len(sample.tokens_b), len(sample.labels_b))
77+
print(sample)

0 commit comments

Comments
 (0)