Skip to content

Commit bc44421

Browse files
committed
Added unit tests for bidirectional huggingface models
1 parent a4c52f4 commit bc44421

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

tests/test_model_helpers/test_huggingface.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,36 +11,38 @@
1111

1212

1313
class TestNextWord:
14-
@pytest.mark.parametrize('model_identifier, expected_next_word', [
15-
pytest.param('bert-base-uncased', '.', marks=pytest.mark.memory_intense),
16-
pytest.param('gpt2-xl', 'jumps', marks=pytest.mark.memory_intense),
17-
('distilgpt2', 'es'),
14+
@pytest.mark.parametrize('model_identifier, expected_next_word, bidirectional', [
15+
pytest.param('bert-base-uncased', 'and', True, marks=pytest.mark.memory_intense),
16+
pytest.param('bert-base-uncased', '.', False, marks=pytest.mark.memory_intense),
17+
pytest.param('gpt2-xl', 'jumps', False, marks=pytest.mark.memory_intense),
18+
('distilgpt2', 'es', False),
1819
])
19-
def test_single_string(self, model_identifier, expected_next_word):
20+
def test_single_string(self, model_identifier, expected_next_word, bidirectional):
2021
"""
2122
This is a simple test that takes in text = 'the quick brown fox', and tests the next word.
2223
This test is a stand-in prototype to check if our model definitions are correct.
2324
"""
2425

25-
model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={})
26+
model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={}, bidirectional=bidirectional)
2627
text = 'the quick brown fox'
2728
_logger.info(f'Running {model.identifier()} with text "{text}"')
2829
model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
2930
next_word = model.digest_text(text)['behavior'].values
3031
assert next_word == expected_next_word
3132

32-
@pytest.mark.parametrize('model_identifier, expected_next_words', [
33-
pytest.param('bert-base-uncased', ['.', '.', '.'], marks=pytest.mark.memory_intense),
34-
pytest.param('gpt2-xl', ['jumps', 'the', 'dog'], marks=pytest.mark.memory_intense),
35-
('distilgpt2', ['es', 'the', 'fox']),
33+
@pytest.mark.parametrize('model_identifier, expected_next_words, bidirectional', [
34+
pytest.param('bert-base-uncased', [';', 'the', 'water'], True, marks=pytest.mark.memory_intense),
35+
pytest.param('bert-base-uncased', ['.', '.', '.'], False, marks=pytest.mark.memory_intense),
36+
pytest.param('gpt2-xl', ['jumps', 'the', 'dog'], False, marks=pytest.mark.memory_intense),
37+
('distilgpt2', ['es', 'the', 'fox'], False),
3638
])
37-
def test_list_input(self, model_identifier, expected_next_words):
39+
def test_list_input(self, model_identifier, expected_next_words, bidirectional):
3840
"""
3941
This is a simple test that takes in text = ['the quick brown fox', 'jumps over', 'the lazy'], and tests the
4042
next word for each text part in the list.
4143
This test is a stand-in prototype to check if our model definitions are correct.
4244
"""
43-
model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={})
45+
model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={}, bidirectional=bidirectional)
4446
text = ['the quick brown fox', 'jumps over', 'the lazy']
4547
_logger.info(f'Running {model.identifier()} with text "{text}"')
4648
model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
@@ -173,6 +175,25 @@ def test_one_text_single_target(self):
173175
assert len(representations['neuroid']) == 768
174176
_logger.info(f'representation shape is correct: {representations.shape}')
175177

178+
@pytest.mark.memory_intense
179+
def test_one_text_single_target_bidirectional(self):
180+
"""
181+
This is a simple test that takes in text = 'the quick brown fox', and asserts that a bidirectional BERT model
182+
layer indexed by `representation_layer` has 1 text presentation and 768 neurons. This test is a stand-in prototype
183+
to check if our model definitions are correct.
184+
"""
185+
model = HuggingfaceSubject(model_id='bert-base-uncased', region_layer_mapping={
186+
ArtificialSubject.RecordingTarget.language_system: 'bert.encoder.layer.4'})
187+
text = 'the quick brown fox'
188+
_logger.info(f'Running {model.identifier()} with text "{text}"')
189+
model.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system,
190+
recording_type=ArtificialSubject.RecordingType.fMRI)
191+
representations = model.digest_text(text)['neural']
192+
assert len(representations['presentation']) == 1
193+
assert representations['stimulus'].squeeze() == text
194+
assert len(representations['neuroid']) == 768
195+
_logger.info(f'representation shape is correct: {representations.shape}')
196+
176197
@pytest.mark.memory_intense
177198
def test_one_text_two_targets(self):
178199
model = HuggingfaceSubject(model_id='distilgpt2', region_layer_mapping={

0 commit comments

Comments
 (0)