1
1
import os
2
+ import threading
3
+ from copy import deepcopy
2
4
3
5
import numpy as np
4
6
import pytest
7
9
from transformers .models .table_transformer .modeling_table_transformer import (
8
10
TableTransformerDecoder ,
9
11
)
10
- from copy import deepcopy
11
12
12
13
import unstructured_inference .models .table_postprocess as postprocess
13
14
from unstructured_inference .models import tables
@@ -572,7 +573,7 @@ def test_load_table_model_raises_when_not_available(model_path):
572
573
573
574
574
575
@pytest .mark .parametrize (
575
- "bbox1, bbox2, expected_result" ,
576
+ ( "bbox1" , " bbox2" , " expected_result") ,
576
577
[
577
578
((0 , 0 , 5 , 5 ), (2 , 2 , 7 , 7 ), 0.36 ),
578
579
((0 , 0 , 0 , 0 ), (6 , 6 , 10 , 10 ), 0 ),
@@ -921,7 +922,9 @@ def test_table_prediction_output_format(
921
922
)
922
923
if output_format :
923
924
result = table_transformer .run_prediction (
924
- example_image , result_format = output_format , ocr_tokens = mocked_ocr_tokens
925
+ example_image ,
926
+ result_format = output_format ,
927
+ ocr_tokens = mocked_ocr_tokens ,
925
928
)
926
929
else :
927
930
result = table_transformer .run_prediction (example_image , ocr_tokens = mocked_ocr_tokens )
@@ -952,7 +955,9 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error(
952
955
)
953
956
with pytest .raises (ValueError ):
954
957
table_transformer .run_prediction (
955
- example_image , result_format = "Wrong format" , ocr_tokens = mocked_ocr_tokens
958
+ example_image ,
959
+ result_format = "Wrong format" ,
960
+ ocr_tokens = mocked_ocr_tokens ,
956
961
)
957
962
958
963
@@ -991,7 +996,8 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
991
996
],
992
997
)
993
998
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold (
994
- thresholds , expected_object_number
999
+ thresholds ,
1000
+ expected_object_number ,
995
1001
):
996
1002
objects = [
997
1003
{"label" : "0" , "score" : 0.2 },
@@ -1010,7 +1016,8 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_
1010
1016
],
1011
1017
)
1012
1018
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes (
1013
- thresholds , expected_object_number
1019
+ thresholds ,
1020
+ expected_object_number ,
1014
1021
):
1015
1022
objects = [
1016
1023
{"label" : "0" , "score" : 0.2 },
@@ -1800,7 +1807,7 @@ def test_compute_confidence_score_zero_division_error_handling():
1800
1807
1801
1808
1802
1809
@pytest .mark .parametrize (
1803
- "column_span_score, row_span_score, expected_text_to_indexes" ,
1810
+ ( "column_span_score" , " row_span_score" , " expected_text_to_indexes") ,
1804
1811
[
1805
1812
(
1806
1813
0.9 ,
@@ -1827,7 +1834,9 @@ def test_compute_confidence_score_zero_division_error_handling():
1827
1834
],
1828
1835
)
1829
1836
def test_subcells_filtering_when_overlapping_spanning_cells (
1830
- column_span_score , row_span_score , expected_text_to_indexes
1837
+ column_span_score ,
1838
+ row_span_score ,
1839
+ expected_text_to_indexes ,
1831
1840
):
1832
1841
"""
1833
1842
# table
@@ -1894,3 +1903,17 @@ def test_subcells_filtering_when_overlapping_spanning_cells(
1894
1903
1895
1904
predicted_cells_after_reorder , _ = structure_to_cells (saved_table_structure , tokens = tokens )
1896
1905
assert predicted_cells_after_reorder == predicted_cells
1906
+
1907
+
1908
+ def test_model_init_is_thread_safe ():
1909
+ threads = []
1910
+ tables .tables_agent .model = None
1911
+ for i in range (5 ):
1912
+ thread = threading .Thread (target = tables .load_agent )
1913
+ threads .append (thread )
1914
+ thread .start ()
1915
+
1916
+ for thread in threads :
1917
+ thread .join ()
1918
+
1919
+ assert tables .tables_agent .model is not None
0 commit comments