diff --git a/langextract/annotation.py b/langextract/annotation.py
index 4f0ec7b..0782d4b 100644
--- a/langextract/annotation.py
+++ b/langextract/annotation.py
@@ -357,13 +357,14 @@ def _annotate_documents_single_pass(
           "Completing annotation for document ID %s.",
           curr_document.document_id,
       )
+      doc_extractions = list(annotated_extractions)
       annotated_doc = data.AnnotatedDocument(
           document_id=curr_document.document_id,
-          extractions=annotated_extractions,
+          extractions=doc_extractions,
           text=curr_document.text,
       )
       yield annotated_doc
-      annotated_extractions.clear()
+      annotated_extractions = []  # Ensure next document gets a fresh list.
 
       curr_document = next(doc_iter, None)
       assert curr_document is not None, (
@@ -400,13 +401,15 @@ def _annotate_documents_single_pass(
   logging.info(
       "Finalizing annotation for document ID %s.", curr_document.document_id
   )
+  doc_extractions = list(annotated_extractions)
   annotated_doc = data.AnnotatedDocument(
       document_id=curr_document.document_id,
-      extractions=annotated_extractions,
+      extractions=doc_extractions,
       text=curr_document.text,
   )
   yield annotated_doc
+  annotated_extractions = []  # Fresh list prevents bleed into future docs.
 
   logging.info("Document annotation completed.")
diff --git a/tests/annotation_test.py b/tests/annotation_test.py
index 78df7dd..ce9d97c 100644
--- a/tests/annotation_test.py
+++ b/tests/annotation_test.py
@@ -742,6 +742,93 @@ def mock_infer_side_effect(batch_prompts, **kwargs):
 
     self.assertGreaterEqual(mock_language_model.infer.call_count, 0)
 
+  def test_annotate_documents_prevents_document_bleed(self):
+    mock_language_model = mock.Mock()
+    call_outputs = [
+        [types.ScoredOutput(score=1.0, output="DOC1")],
+        [types.ScoredOutput(score=1.0, output="DOC2")],
+    ]
+
+    def infer_side_effect(batch_prompts, **kwargs):
+      self.assertLen(batch_prompts, 1)
+      self.assertTrue(call_outputs)
+      outputs_for_call = call_outputs.pop(0)
+      return [list(outputs_for_call)]
+
+    mock_language_model.infer.side_effect = infer_side_effect
+
+    class _FakeResolver:
+
+      def resolve(self, llm_output, **kwargs):
+        if llm_output == "DOC1":
+          return [
+              data.Extraction(
+                  extraction_class="Person",
+                  extraction_text="Bob",
+              )
+          ]
+        if llm_output == "DOC2":
+          return [
+              data.Extraction(
+                  extraction_class="Person",
+                  extraction_text="Charlie",
+              )
+          ]
+        raise AssertionError(f"Unexpected llm_output: {llm_output!r}")
+
+      def align(
+          self,
+          extractions,
+          chunk_text,
+          token_offset,
+          char_offset,
+          **kwargs,
+      ):
+        return list(extractions)
+
+    annotator = annotation.Annotator(
+        language_model=mock_language_model,
+        prompt_template=prompting.PromptTemplateStructured(description=""),
+    )
+
+    documents = [
+        data.Document(text="Doc1 text", document_id="doc1"),
+        data.Document(text="Doc2 text", document_id="doc2"),
+    ]
+
+    annotations = list(
+        annotator.annotate_documents(
+            documents,
+            resolver=_FakeResolver(),
+            max_char_buffer=200,
+            batch_length=1,
+            debug=False,
+            show_progress=False,
+        )
+    )
+
+    self.assertLen(annotations, 2)
+    self.assertEmpty(call_outputs)
+    self.assertEqual(annotations[0].document_id, "doc1")
+    self.assertEqual(annotations[1].document_id, "doc2")
+    self.assertIsNotNone(annotations[0].extractions)
+    self.assertIsNotNone(annotations[1].extractions)
+    self.assertEqual(
+        [
+            extraction.extraction_text
+            for extraction in annotations[0].extractions
+        ],
+        ["Bob"],
+    )
+    self.assertEqual(
+        [
+            extraction.extraction_text
+            for extraction in annotations[1].extractions
+        ],
+        ["Charlie"],
+    )
+    self.assertIsNot(annotations[0].extractions, annotations[1].extractions)
+
   @parameterized.named_parameters(
       dict(
           testcase_name="same_document_id_contiguous",
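
For reviewers, a minimal sketch of the aliasing pitfall this patch fixes, independent of langextract; the generator names below are hypothetical stand-ins for _annotate_documents_single_pass. Yielding a list and then calling .clear() on it mutates the very object the consumer already received, so every previously yielded document loses its extractions; copying before the yield and rebinding afterwards keeps each yielded document intact.

def buggy_generator():
  results = []
  for doc_id in ("doc1", "doc2"):
    results.append(doc_id)
    yield results    # Consumer receives a reference, not a copy.
    results.clear()  # Mutates the object the consumer already holds.


def fixed_generator():
  results = []
  for doc_id in ("doc1", "doc2"):
    results.append(doc_id)
    yield list(results)  # Snapshot: consumer owns an independent copy.
    results = []         # Rebinding never touches previously yielded lists.


print(list(buggy_generator()))  # [[], []], both entries alias one emptied list.
print(list(fixed_generator()))  # [['doc1'], ['doc2']]

The patch applies both safeguards: list(annotated_extractions) snapshots the extractions at yield time, and rebinding to [] guarantees the next document starts from an empty list even if a later refactor drops the snapshot.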