Skip to content

Commit 4214938

Browse files
committed
fixed test
1 parent 564c64c commit 4214938

File tree

2 files changed

+169
-29
lines changed

2 files changed

+169
-29
lines changed

tests/test_gemini_ai.py

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,21 @@ class TestGeminiAI(unittest.TestCase):
1414
@patch("src.app.gemini_ai.genai.configure")
1515
def setUp(self, mock_configure, mock_model):
1616
self.api_key = "test_api_key"
17+
self.mock_model_instance = MagicMock()
18+
mock_model.return_value = self.mock_model_instance
1719
self.gemini_ai = GeminiAI(api_key=self.api_key)
1820
self.transcript_text = "This is a test transcript."
19-
self.mock_model = mock_model
2021

21-
@patch("src.app.gemini_ai.genai.GenerativeModel")
22-
def test_initialization(self, mock_model):
23-
self.assertIsNotNone(self.gemini_ai.model)
22+
def test_initialization_with_key(self):
23+
with patch("src.app.gemini_ai.genai.configure") as mock_configure, \
24+
patch("src.app.gemini_ai.genai.GenerativeModel") as mock_model:
25+
26+
gemini_ai = GeminiAI(api_key="test_key")
27+
mock_configure.assert_called_once_with(api_key="test_key")
28+
mock_model.assert_called_once_with("gemini-1.5-flash")
29+
self.assertIsNotNone(gemini_ai.model)
2430

31+
def test_initialization_without_key(self):
2532
gemini_ai_no_key = GeminiAI(api_key=None)
2633
self.assertIsNone(gemini_ai_no_key.model)
2734

@@ -32,29 +39,109 @@ def test_is_configured(self):
3239
self.assertFalse(gemini_ai_no_key.is_configured())
3340

3441
def test_generate_summary_success(self):
35-
self.mock_model.return_value.generate_content.return_value = MagicMock(
36-
text="This is a summary."
37-
)
42+
mock_response = MagicMock()
43+
mock_response.text = "This is a summary."
44+
self.mock_model_instance.generate_content.return_value = mock_response
45+
3846
summary = self.gemini_ai.generate_summary(self.transcript_text)
3947
self.assertEqual(summary, "This is a summary.")
48+
self.mock_model_instance.generate_content.assert_called_once()
4049

4150
def test_generate_summary_failure(self):
42-
self.mock_model.return_value.generate_content.side_effect = Exception(
43-
"API Error"
44-
)
45-
51+
self.mock_model_instance.generate_content.side_effect = Exception("API Error")
52+
4653
summary = self.gemini_ai.generate_summary(self.transcript_text)
4754
self.assertIn("Error generating summary", summary)
55+
self.assertIn("API Error", summary)
56+
57+
def test_extract_key_quotes_success(self):
58+
mock_response = MagicMock()
59+
mock_response.text = "These are key quotes."
60+
self.mock_model_instance.generate_content.return_value = mock_response
61+
62+
quotes = self.gemini_ai.extract_key_quotes(self.transcript_text)
63+
self.assertEqual(quotes, "These are key quotes.")
64+
65+
def test_extract_key_quotes_failure(self):
66+
self.mock_model_instance.generate_content.side_effect = Exception("API Error")
67+
68+
quotes = self.gemini_ai.extract_key_quotes(self.transcript_text)
69+
self.assertIn("Error extracting quotes", quotes)
70+
71+
def test_create_study_guide_success(self):
72+
mock_response = MagicMock()
73+
mock_response.text = "This is a study guide."
74+
self.mock_model_instance.generate_content.return_value = mock_response
75+
76+
study_guide = self.gemini_ai.create_study_guide(self.transcript_text)
77+
self.assertEqual(study_guide, "This is a study guide.")
78+
79+
def test_create_study_guide_failure(self):
80+
self.mock_model_instance.generate_content.side_effect = Exception("API Error")
81+
82+
study_guide = self.gemini_ai.create_study_guide(self.transcript_text)
83+
self.assertIn("Error creating study guide", study_guide)
84+
85+
def test_generate_qa_success(self):
86+
mock_response = MagicMock()
87+
mock_response.text = "Q: Test question?\nA: Test answer."
88+
self.mock_model_instance.generate_content.return_value = mock_response
89+
90+
qa = self.gemini_ai.generate_qa(self.transcript_text)
91+
self.assertEqual(qa, "Q: Test question?\nA: Test answer.")
92+
93+
def test_generate_qa_failure(self):
94+
self.mock_model_instance.generate_content.side_effect = Exception("API Error")
95+
96+
qa = self.gemini_ai.generate_qa(self.transcript_text)
97+
self.assertIn("Error generating Q&A", qa)
98+
99+
def test_create_flashcards_success(self):
100+
mock_response = MagicMock()
101+
mock_response.text = "FRONT: Test term\nBACK: Test definition\n---"
102+
self.mock_model_instance.generate_content.return_value = mock_response
103+
104+
flashcards = self.gemini_ai.create_flashcards(self.transcript_text)
105+
self.assertEqual(flashcards, "FRONT: Test term\nBACK: Test definition\n---")
106+
107+
def test_create_flashcards_failure(self):
108+
self.mock_model_instance.generate_content.side_effect = Exception("API Error")
109+
110+
flashcards = self.gemini_ai.create_flashcards(self.transcript_text)
111+
self.assertIn("Error creating flashcards", flashcards)
112+
113+
def test_highlight_insights_success(self):
114+
mock_response = MagicMock()
115+
mock_response.text = "🔍 Key Insights: Test insights"
116+
self.mock_model_instance.generate_content.return_value = mock_response
117+
118+
insights = self.gemini_ai.highlight_insights(self.transcript_text)
119+
self.assertEqual(insights, "🔍 Key Insights: Test insights")
120+
121+
def test_highlight_insights_failure(self):
122+
self.mock_model_instance.generate_content.side_effect = Exception("API Error")
123+
124+
insights = self.gemini_ai.highlight_insights(self.transcript_text)
125+
self.assertIn("Error extracting insights", insights)
48126

49127
def test_chat_with_transcript_success(self):
50-
self.mock_model.return_value.generate_content.return_value = MagicMock(
51-
text="This is an answer."
52-
)
128+
mock_response = MagicMock()
129+
mock_response.text = "This is an answer."
130+
self.mock_model_instance.generate_content.return_value = mock_response
131+
53132
answer = self.gemini_ai.chat_with_transcript(
54133
self.transcript_text, "What is this?"
55134
)
56135
self.assertEqual(answer, "This is an answer.")
57136

137+
def test_chat_with_transcript_failure(self):
138+
self.mock_model_instance.generate_content.side_effect = Exception("API Error")
139+
140+
answer = self.gemini_ai.chat_with_transcript(
141+
self.transcript_text, "What is this?"
142+
)
143+
self.assertIn("Error in chat", answer)
144+
58145

59146
if __name__ == "__main__":
60-
unittest.main()
147+
unittest.main()

tests/test_transcript_extractor.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
55

66
import unittest
7-
from unittest.mock import patch, MagicMock
7+
from unittest.mock import patch, MagicMock, mock_open
88
from src.app.transcript_extractor import YouTubeTranscriptExtractor
99

1010

@@ -38,23 +38,32 @@ def test_get_transcript_youtube_api_failure(self, mock_api):
3838
transcript = self.extractor.get_transcript_youtube_api("dQw4w9WgXcQ")
3939
self.assertIsNone(transcript)
4040

41+
@patch("src.app.transcript_extractor.tempfile.TemporaryDirectory")
42+
@patch("src.app.transcript_extractor.os.listdir")
43+
@patch("builtins.open", new_callable=mock_open, read_data="WEBVTT\n\n00:00:00.000 --> 00:00:01.000\nhello from ytdlp")
4144
@patch("src.app.transcript_extractor.yt_dlp")
42-
def test_get_transcript_ytdlp_success(self, mock_ytdlp):
45+
def test_get_transcript_ytdlp_success(self, mock_ytdlp, mock_file, mock_listdir, mock_tempdir):
46+
# Mock the temporary directory
47+
mock_temp_path = "/tmp/test"
48+
mock_tempdir.return_value.__enter__.return_value = mock_temp_path
49+
50+
# Mock the yt-dlp instance
4351
mock_ydl_instance = MagicMock()
4452
mock_ydl_instance.extract_info.return_value = {
45-
"subtitles": {"en": [{"ext": "vtt", "url": "..."}]}
53+
"title": "Test Video",
54+
"duration": 60,
55+
"subtitles": {},
56+
"automatic_captions": {}
4657
}
4758
mock_ytdlp.YoutubeDL.return_value.__enter__.return_value = mock_ydl_instance
48-
49-
with patch.object(
50-
self.extractor,
51-
"_process_ytdlp_subtitles",
52-
return_value=[{"text": "hello from ytdlp"}],
53-
):
54-
transcript = self.extractor.get_transcript_ytdlp("dQw4w9WgXcQ")
55-
self.assertIsNotNone(transcript)
56-
if transcript:
57-
self.assertEqual(transcript[0]["text"], "hello from ytdlp")
59+
60+
# Mock file listing to return a VTT file
61+
mock_listdir.return_value = ["test_video.en.vtt"]
62+
63+
transcript = self.extractor.get_transcript_ytdlp("dQw4w9WgXcQ")
64+
self.assertIsNotNone(transcript)
65+
if transcript:
66+
self.assertEqual(transcript[0]["text"], "hello from ytdlp")
5867

5968
@patch("src.app.transcript_extractor.yt_dlp")
6069
def test_get_transcript_ytdlp_failure(self, mock_ytdlp):
@@ -94,6 +103,50 @@ def test_extract_transcript_fallback_to_ytdlp(self, mock_ytdlp, mock_api):
94103
self.assertIsNotNone(transcript)
95104
self.assertIn("Success using yt-dlp", status)
96105

106+
def test_parse_vtt_content(self):
107+
vtt_content = """WEBVTT
108+
109+
00:00:00.000 --> 00:00:01.000
110+
hello world
111+
112+
00:00:01.000 --> 00:00:02.000
113+
this is a test"""
114+
115+
transcript = self.extractor._parse_vtt_content(vtt_content)
116+
self.assertEqual(len(transcript), 2)
117+
self.assertEqual(transcript[0]["text"], "hello world")
118+
self.assertEqual(transcript[0]["start"], 0.0)
119+
self.assertEqual(transcript[1]["text"], "this is a test")
120+
self.assertEqual(transcript[1]["start"], 1.0)
121+
122+
def test_parse_srt_content(self):
123+
srt_content = """1
124+
00:00:00,000 --> 00:00:01,000
125+
hello world
126+
127+
2
128+
00:00:01,000 --> 00:00:02,000
129+
this is a test"""
130+
131+
transcript = self.extractor._parse_srt_content(srt_content)
132+
self.assertEqual(len(transcript), 2)
133+
self.assertEqual(transcript[0]["text"], "hello world")
134+
self.assertEqual(transcript[0]["start"], 0.0)
135+
self.assertEqual(transcript[1]["text"], "this is a test")
136+
self.assertEqual(transcript[1]["start"], 1.0)
137+
138+
def test_parse_timestamp(self):
139+
# Test VTT timestamp parsing
140+
self.assertEqual(self.extractor._parse_timestamp("00:00:01.500"), 1.5)
141+
self.assertEqual(self.extractor._parse_timestamp("00:01:30.250"), 90.25)
142+
self.assertEqual(self.extractor._parse_timestamp("01:00:00.000"), 3600.0)
143+
144+
def test_parse_srt_timestamp(self):
145+
# Test SRT timestamp parsing
146+
self.assertEqual(self.extractor._parse_srt_timestamp("00:00:01,500"), 1.5)
147+
self.assertEqual(self.extractor._parse_srt_timestamp("00:01:30,250"), 90.25)
148+
self.assertEqual(self.extractor._parse_srt_timestamp("01:00:00,000"), 3600.0)
149+
97150

98151
if __name__ == "__main__":
99-
unittest.main()
152+
unittest.main()

0 commit comments

Comments
 (0)