Skip to content

Commit fbbdd9c

Browse files
authored
Dependency Updates (#37)
* Update pytorch-partial-tagger
* Align new interfaces.
* Bump version
* Update requirements.txt
1 parent 3a5342e commit fbbdd9c

File tree

5 files changed

+50
-53
lines changed

5 files changed

+50
-53
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ requires-python = ">=3.8"
88

99
[tool.poetry]
1010
name = "spacy-partial-tagger"
11-
version = "0.15.1"
11+
version = "0.15.2"
1212
description = "Sequence Tagger for Partially Annotated Dataset in spaCy"
1313
authors = ["yasufumi <[email protected]>"]
1414
license = "MIT"
@@ -27,7 +27,7 @@ transformers = {extras = ["ja"], version = "^4.25.1"}
2727
torch = "^2.0.1"
2828
spacy = {extras = ["transformers"], version = "^3.3.1"}
2929
spacy-alignments = "^0.8.5"
30-
pytorch-partial-tagger = "^0.1.9"
30+
pytorch-partial-tagger = "^0.1.12"
3131

3232
[tool.poetry.group.dev.dependencies]
3333
mypy = "^1.3.0"

requirements.txt

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@ black==22.12.0 ; python_version >= "3.8" and python_version < "4.0"
22
blis==0.7.9 ; python_version >= "3.8" and python_version < "4.0"
33
catalogue==2.0.8 ; python_version >= "3.8" and python_version < "4.0"
44
certifi==2023.5.7 ; python_version >= "3.8" and python_version < "4.0"
5-
charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "4.0"
6-
click==8.1.3 ; python_version >= "3.8" and python_version < "4.0"
5+
charset-normalizer==3.2.0 ; python_version >= "3.8" and python_version < "4.0"
6+
click==8.1.5 ; python_version >= "3.8" and python_version < "4.0"
77
colorama==0.4.6 ; python_version >= "3.8" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows"
8-
confection==0.0.4 ; python_version >= "3.8" and python_version < "4.0"
8+
confection==0.1.0 ; python_version >= "3.8" and python_version < "4.0"
99
coverage[toml]==7.2.7 ; python_version >= "3.8" and python_version < "4.0"
1010
cymem==2.0.7 ; python_version >= "3.8" and python_version < "4.0"
11-
exceptiongroup==1.1.1 ; python_version >= "3.8" and python_version < "3.11"
11+
exceptiongroup==1.1.2 ; python_version >= "3.8" and python_version < "3.11"
1212
filelock==3.12.2 ; python_version >= "3.8" and python_version < "4.0"
1313
flake8==4.0.1 ; python_version >= "3.8" and python_version < "4.0"
1414
fsspec==2023.6.0 ; python_version >= "3.8" and python_version < "4.0"
1515
fugashi==1.2.1 ; python_version >= "3.8" and python_version < "4.0"
16-
huggingface-hub==0.15.1 ; python_version >= "3.8" and python_version < "4.0"
16+
huggingface-hub==0.16.4 ; python_version >= "3.8" and python_version < "4.0"
1717
idna==3.4 ; python_version >= "3.8" and python_version < "4.0"
1818
iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0"
1919
ipadic==1.0.0 ; python_version >= "3.8" and python_version < "4.0"
@@ -25,36 +25,36 @@ mccabe==0.6.1 ; python_version >= "3.8" and python_version < "4.0"
2525
mpmath==1.3.0 ; python_version >= "3.8" and python_version < "4.0"
2626
murmurhash==1.0.9 ; python_version >= "3.8" and python_version < "4.0"
2727
mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4.0"
28-
mypy==1.3.0 ; python_version >= "3.8" and python_version < "4.0"
28+
mypy==1.4.1 ; python_version >= "3.8" and python_version < "4.0"
2929
networkx==3.1 ; python_version >= "3.8" and python_version < "4.0"
30-
numpy==1.24.3 ; python_version >= "3.8" and python_version < "4.0"
30+
numpy==1.24.4 ; python_version >= "3.8" and python_version < "4.0"
3131
packaging==23.1 ; python_version >= "3.8" and python_version < "4.0"
3232
pathspec==0.11.1 ; python_version >= "3.8" and python_version < "4.0"
33-
pathy==0.10.1 ; python_version >= "3.8" and python_version < "4.0"
33+
pathy==0.10.2 ; python_version >= "3.8" and python_version < "4.0"
3434
plac==1.3.5 ; python_version >= "3.8" and python_version < "4.0"
35-
platformdirs==3.6.0 ; python_version >= "3.8" and python_version < "4.0"
36-
pluggy==1.0.0 ; python_version >= "3.8" and python_version < "4.0"
35+
platformdirs==3.9.1 ; python_version >= "3.8" and python_version < "4.0"
36+
pluggy==1.2.0 ; python_version >= "3.8" and python_version < "4.0"
3737
preshed==3.0.8 ; python_version >= "3.8" and python_version < "4.0"
3838
pycodestyle==2.8.0 ; python_version >= "3.8" and python_version < "4.0"
39-
pydantic==1.10.9 ; python_version >= "3.8" and python_version < "4.0"
39+
pydantic==1.10.11 ; python_version >= "3.8" and python_version < "4.0"
4040
pyflakes==2.4.0 ; python_version >= "3.8" and python_version < "4.0"
4141
pytest-cov==3.0.0 ; python_version >= "3.8" and python_version < "4.0"
42-
pytest==7.3.2 ; python_version >= "3.8" and python_version < "4.0"
43-
pytorch-partial-tagger==0.1.9 ; python_version >= "3.8" and python_version < "4.0"
42+
pytest==7.4.0 ; python_version >= "3.8" and python_version < "4.0"
43+
pytorch-partial-tagger==0.1.12 ; python_version >= "3.8" and python_version < "4.0"
4444
pyyaml==6.0 ; python_version >= "3.8" and python_version < "4.0"
4545
regex==2023.6.3 ; python_version >= "3.8" and python_version < "4.0"
4646
requests==2.31.0 ; python_version >= "3.8" and python_version < "4.0"
4747
rhoknp==1.3.0 ; python_version >= "3.8" and python_version < "4.0"
4848
ruff==0.0.270 ; python_version >= "3.8" and python_version < "4.0"
4949
safetensors==0.3.1 ; python_version >= "3.8" and python_version < "4.0"
50-
setuptools==67.8.0 ; python_version >= "3.8" and python_version < "4.0"
50+
setuptools==68.0.0 ; python_version >= "3.8" and python_version < "4.0"
5151
smart-open==6.3.0 ; python_version >= "3.8" and python_version < "4.0"
5252
spacy-alignments==0.8.6 ; python_version >= "3.8" and python_version < "4.0"
5353
spacy-legacy==3.0.12 ; python_version >= "3.8" and python_version < "4.0"
5454
spacy-loggers==1.0.4 ; python_version >= "3.8" and python_version < "4.0"
5555
spacy-transformers==1.2.5 ; python_version >= "3.8" and python_version < "4.0"
56-
spacy==3.5.3 ; python_version >= "3.8" and python_version < "4.0"
57-
spacy[transformers]==3.5.3 ; python_version >= "3.8" and python_version < "4.0"
56+
spacy==3.6.0 ; python_version >= "3.8" and python_version < "4.0"
57+
spacy[transformers]==3.6.0 ; python_version >= "3.8" and python_version < "4.0"
5858
srsly==2.4.6 ; python_version >= "3.8" and python_version < "4.0"
5959
sudachidict-core==20230110 ; python_version >= "3.8" and python_version < "4.0"
6060
sudachipy==0.6.7 ; python_version >= "3.8" and python_version < "4.0"
@@ -66,8 +66,8 @@ torch==2.0.1 ; python_version >= "3.8" and python_version < "4.0"
6666
tqdm==4.65.0 ; python_version >= "3.8" and python_version < "4.0"
6767
transformers==4.30.2 ; python_version >= "3.8" and python_version < "4.0"
6868
transformers[ja]==4.30.2 ; python_version >= "3.8" and python_version < "4.0"
69-
typer==0.7.0 ; python_version >= "3.8" and python_version < "4.0"
70-
typing-extensions==4.6.3 ; python_version >= "3.8" and python_version < "4.0"
69+
typer==0.9.0 ; python_version >= "3.8" and python_version < "4.0"
70+
typing-extensions==4.7.1 ; python_version >= "3.8" and python_version < "4.0"
7171
unidic-lite==1.0.8 ; python_version >= "3.8" and python_version < "4.0"
7272
unidic==1.1.0 ; python_version >= "3.8" and python_version < "4.0"
7373
urllib3==2.0.3 ; python_version >= "3.8" and python_version < "4.0"

spacy_partial_tagger/pipeline.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
import srsly
44
import torch
5-
from partial_tagger.data import CharBasedTags, LabelSet
5+
from partial_tagger.data import LabelSet
66
from partial_tagger.data.batch.tag import TagsBatch
7-
from partial_tagger.data.batch.text import create_token_based_tags
87
from partial_tagger.training import compute_partially_supervised_loss
98
from partial_tagger.utils import create_tag
109
from spacy import util
@@ -52,15 +51,15 @@ def set_annotations(
5251
docs: List[Doc],
5352
tag_indices: Floats2d,
5453
) -> None:
55-
tokenized_texts = [doc.user_data["tokenized_text"] for doc in docs]
5654

57-
tags_batch = create_token_based_tags(
58-
tokenized_texts, tag_indices, self.label_set, self.padding_index
59-
)
60-
61-
for doc, tags in zip(docs, tags_batch):
55+
for doc, indices in zip(docs, tag_indices.tolist()):
56+
indices = [index for index in indices if index != self.padding_index]
57+
alignment = doc.user_data["alignment"]
6258
ents = []
63-
for tag in tags:
59+
for tag in alignment.create_char_based_tags(
60+
tag_indices=indices,
61+
label_set=self.label_set,
62+
):
6463
span = doc.char_span(tag.start, tag.start + tag.length, tag.label)
6564
if span:
6665
ents.append(span)
@@ -114,24 +113,26 @@ def get_loss(
114113
) -> Tuple[float, Floats4d]:
115114
scores_pt = xp2torch(scores, requires_grad=True)
116115

117-
token_based_tags = []
116+
char_based_tags = []
117+
alignments = []
118118
lengths = []
119119
for example in examples:
120120
tags = tuple(
121121
create_tag(ent.start_char, len(ent.text), ent.label_)
122122
for ent in example.y.ents
123123
)
124-
tokenized_text = example.x.user_data["tokenized_text"]
125-
token_based_tags.append(
126-
CharBasedTags(tags, example.x.text).convert_to_token_based(
127-
tokenized_text
128-
)
129-
)
130-
lengths.append(tokenized_text.num_tokens)
124+
char_based_tags.append(tags)
131125

132-
tags_batch = TagsBatch(tuple(token_based_tags), self.label_set)
126+
alignment = example.x.user_data["alignment"]
127+
lengths.append(alignment.num_tokens)
128+
alignments.append(alignment)
129+
130+
tags_batch = TagsBatch(
131+
tags_batch=tuple(char_based_tags),
132+
alignments=alignments,
133+
)
133134
tags_batch.to(scores_pt.device)
134-
tag_bitmap = tags_batch.get_tag_bitmap()
135+
tag_bitmap = tags_batch.get_tag_bitmap(self.label_set)
135136

136137
max_length = max(lengths)
137138
mask = torch.tensor(

spacy_partial_tagger/tagger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def forward(
4747

4848
text_batch = tokenizer(tuple(doc.text for doc in X))
4949

50-
for doc, text in zip(X, text_batch.tokenized_texts):
51-
doc.user_data["tokenized_text"] = text
50+
for doc, alignment in zip(X, text_batch.alignments):
51+
doc.user_data["alignment"] = alignment
5252

5353
device = get_torch_default_device()
5454
text_batch.to(device)

spacy_partial_tagger/tokenizer.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Optional, Tuple
22

3-
import torch
4-
from partial_tagger.data import Span, TokenizedText
3+
from partial_tagger.data import Alignment, Span
54
from partial_tagger.data.batch.text import (
65
BaseTokenizer,
76
TextBatch,
@@ -26,6 +25,7 @@ def __init__(
2625

2726
self.__tokenizer_args = tokenizer_args or {
2827
"padding": True,
28+
"truncation": True,
2929
"return_tensors": "pt",
3030
}
3131
self.__tokenizer_args["return_offsets_mapping"] = True
@@ -34,9 +34,10 @@ def __call__(self, texts: Tuple[str]) -> TextBatch:
3434
batch_encoding = self.__tokenizer(texts, **self.__tokenizer_args)
3535

3636
pad_token_id = self.__tokenizer.pad_token_id
37-
tokenized_text_lengths = (batch_encoding.input_ids != pad_token_id).sum(dim=1)
37+
mask = batch_encoding.input_ids != pad_token_id
38+
tokenized_text_lengths = mask.sum(dim=1)
3839

39-
tokenized_texts = []
40+
alignments = []
4041
for _tokenized_text_length, input_ids, text in zip(
4142
tokenized_text_lengths, batch_encoding.input_ids, texts
4243
):
@@ -52,16 +53,11 @@ def __call__(self, texts: Tuple[str]) -> TextBatch:
5253
end = char_span.start + char_span.length
5354
token_indices[start:end] = [token_index] * char_span.length
5455

55-
tokenized_texts.append(
56-
TokenizedText(text, char_spans, tuple(token_indices))
57-
)
56+
alignments.append(Alignment(text, char_spans, tuple(token_indices)))
5857

59-
lengths = [text.num_tokens for text in tokenized_texts]
60-
max_length = max(lengths)
61-
mask = torch.tensor(
62-
[[True] * length + [False] * (max_length - length) for length in lengths]
58+
return TextBatch(
59+
tagger_inputs=batch_encoding, mask=mask, alignments=tuple(alignments)
6360
)
64-
return TextBatch(tuple(tokenized_texts), batch_encoding, mask)
6561

6662

6763
def get_tokenizer(

0 commit comments

Comments (0)