@@ -472,3 +472,49 @@ def test_layoutelements_concatenate():
472
472
assert joint .sources .tolist () == ["yolox" , "yolox" , "ocr" , "ocr" ]
473
473
assert joint .element_class_ids .tolist () == [0 , 1 , 1 , 2 ]
474
474
assert joint .element_class_id_map == {0 : "type0" , 1 : "type1" , 2 : "type2" }
475
+
476
+
477
+ @pytest .mark .parametrize (
478
+ "test_elements" ,
479
+ [
480
+ TextRegions (
481
+ element_coords = np .array (
482
+ [
483
+ [0.0 , 0.0 , 1.0 , 1.0 ],
484
+ [1.0 , 0.0 , 1.5 , 1.0 ],
485
+ [2.0 , 0.0 , 2.5 , 1.0 ],
486
+ [3.0 , 0.0 , 4.0 , 1.0 ],
487
+ [4.0 , 0.0 , 5.0 , 1.0 ],
488
+ ]
489
+ ),
490
+ texts = np .array (["0" , "1" , "2" , "3" , "4" ]),
491
+ sources = np .array (["foo" , "foo" , "foo" , "foo" , "foo" ], dtype = "<U3" ),
492
+ source = np .str_ ("foo" ),
493
+ ),
494
+ LayoutElements (
495
+ element_coords = np .array (
496
+ [
497
+ [0.0 , 0.0 , 1.0 , 1.0 ],
498
+ [1.0 , 0.0 , 1.5 , 1.0 ],
499
+ [2.0 , 0.0 , 2.5 , 1.0 ],
500
+ [3.0 , 0.0 , 4.0 , 1.0 ],
501
+ [4.0 , 0.0 , 5.0 , 1.0 ],
502
+ ]
503
+ ),
504
+ texts = np .array (["0" , "1" , "2" , "3" , "4" ]),
505
+ sources = np .array (["foo" , "foo" , "foo" , "foo" , "foo" ], dtype = "<U3" ),
506
+ source = np .str_ ("foo" ),
507
+ element_probs = np .array ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 ]),
508
+ ),
509
+ ],
510
+ )
511
+ def test_textregions_support_numpy_slicing (test_elements ):
512
+ np .testing .assert_equal (test_elements [1 :4 ].texts , np .array (["1" , "2" , "3" ]))
513
+ np .testing .assert_equal (test_elements [0 ::2 ].texts , np .array (["0" , "2" , "4" ]))
514
+ np .testing .assert_equal (test_elements [[1 , 2 , 4 ]].texts , np .array (["1" , "2" , "4" ]))
515
+ np .testing .assert_equal (test_elements [np .array ([1 , 2 , 4 ])].texts , np .array (["1" , "2" , "4" ]))
516
+ np .testing .assert_equal (
517
+ test_elements [np .array ([True , False , False , True , False ])].texts , np .array (["0" , "3" ])
518
+ )
519
+ if isinstance (test_elements , LayoutElements ):
520
+ np .testing .assert_almost_equal (test_elements [1 :4 ].element_probs , np .array ([0.1 , 0.2 , 0.3 ]))
0 commit comments