Skip to content

Commit a0df407

Browse files
authored
feat: add element slicing support (#432)
I noticed that `LayoutElements` and `TextRegions` already supported slicing, but only through a `slice` method rather than index syntax, so I added `__getitem__` as an alias for `slice`.
1 parent 6a46303 commit a0df407

File tree

5 files changed

+57
-1
lines changed

5 files changed

+57
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 1.0.6
2+
3+
* Add slicing through indexing for vectorized elements
4+
15
## 1.0.5
26

37
* feat: add thread lock to prevent race condition when instantiating singletons

test_unstructured_inference/test_elements.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,49 @@ def test_layoutelements_concatenate():
472472
assert joint.sources.tolist() == ["yolox", "yolox", "ocr", "ocr"]
473473
assert joint.element_class_ids.tolist() == [0, 1, 1, 2]
474474
assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"}
475+
476+
477+
@pytest.mark.parametrize(
478+
"test_elements",
479+
[
480+
TextRegions(
481+
element_coords=np.array(
482+
[
483+
[0.0, 0.0, 1.0, 1.0],
484+
[1.0, 0.0, 1.5, 1.0],
485+
[2.0, 0.0, 2.5, 1.0],
486+
[3.0, 0.0, 4.0, 1.0],
487+
[4.0, 0.0, 5.0, 1.0],
488+
]
489+
),
490+
texts=np.array(["0", "1", "2", "3", "4"]),
491+
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
492+
source=np.str_("foo"),
493+
),
494+
LayoutElements(
495+
element_coords=np.array(
496+
[
497+
[0.0, 0.0, 1.0, 1.0],
498+
[1.0, 0.0, 1.5, 1.0],
499+
[2.0, 0.0, 2.5, 1.0],
500+
[3.0, 0.0, 4.0, 1.0],
501+
[4.0, 0.0, 5.0, 1.0],
502+
]
503+
),
504+
texts=np.array(["0", "1", "2", "3", "4"]),
505+
sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype="<U3"),
506+
source=np.str_("foo"),
507+
element_probs=np.array([0.0, 0.1, 0.2, 0.3, 0.4]),
508+
),
509+
],
510+
)
511+
def test_textregions_support_numpy_slicing(test_elements):
512+
np.testing.assert_equal(test_elements[1:4].texts, np.array(["1", "2", "3"]))
513+
np.testing.assert_equal(test_elements[0::2].texts, np.array(["0", "2", "4"]))
514+
np.testing.assert_equal(test_elements[[1, 2, 4]].texts, np.array(["1", "2", "4"]))
515+
np.testing.assert_equal(test_elements[np.array([1, 2, 4])].texts, np.array(["1", "2", "4"]))
516+
np.testing.assert_equal(
517+
test_elements[np.array([True, False, False, True, False])].texts, np.array(["0", "3"])
518+
)
519+
if isinstance(test_elements, LayoutElements):
520+
np.testing.assert_almost_equal(test_elements[1:4].element_probs, np.array([0.1, 0.2, 0.3]))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.5" # pragma: no cover
1+
__version__ = "1.0.6" # pragma: no cover

unstructured_inference/inference/elements.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ def __post_init__(self):
226226
# we convert to float so data type is more consistent (e.g., None will be np.nan)
227227
self.element_coords = self.element_coords.astype(float)
228228

229+
def __getitem__(self, indices) -> TextRegions:
230+
return self.slice(indices)
231+
229232
def slice(self, indices) -> TextRegions:
230233
"""slice text regions based on indices"""
231234
return TextRegions(

unstructured_inference/inference/layoutelement.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def __eq__(self, other: object) -> bool:
7575
and np.array_equal(self.table_as_cells[mask], other.table_as_cells[mask])
7676
)
7777

78+
def __getitem__(self, indices):
79+
return self.slice(indices)
80+
7881
def slice(self, indices) -> LayoutElements:
7982
"""slice and return only selected indices"""
8083
return LayoutElements(

0 commit comments

Comments
 (0)