Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 9518156

Browse files
DictType: added support for int keys (#298)
* DictType: added support for int keys
* incorporated review comments
* move fixtures

Co-authored-by: Alexander Guschin <[email protected]>
1 parent 9c4c650 commit 9518156

File tree

2 files changed: 107 additions and 8 deletions.

mlem/core/data_type.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
)
2020

2121
import flatdict
22-
from pydantic import BaseModel, validator
22+
from pydantic import BaseModel, StrictInt, StrictStr, validator
2323
from pydantic.main import create_model
2424

2525
from mlem.core.artifacts import Artifacts, Storage
@@ -514,7 +514,7 @@ class DictType(DataType, DataSerializer):
514514
"""
515515

516516
type: ClassVar[str] = "dict"
517-
item_types: Dict[str, DataType]
517+
item_types: Dict[Union[StrictStr, StrictInt], DataType]
518518

519519
@classmethod
520520
def process(cls, obj, **kwargs):
@@ -563,7 +563,7 @@ def get_writer(
563563

564564
def get_model(self, prefix="") -> Type[BaseModel]:
565565
kwargs = {
566-
k: (v.get_serializer().get_model(prefix + k + "_"), ...)
566+
str(k): (v.get_serializer().get_model(prefix + str(k) + "_"), ...)
567567
for k, v in self.item_types.items()
568568
}
569569
return create_model(prefix + "DictType", **kwargs) # type: ignore
@@ -585,9 +585,9 @@ def write(
585585
dtype_reader, art = dtype.get_writer().write(
586586
dtype.copy().bind(data.data[key]),
587587
storage,
588-
posixpath.join(path, key),
588+
posixpath.join(path, str(key)),
589589
)
590-
res[key] = art
590+
res[str(key)] = art
591591
readers[key] = dtype_reader
592592
return DictReader(data_type=data, item_readers=readers), dict(
593593
flatdict.FlatterDict(res, delimiter="/")
@@ -597,13 +597,13 @@ def write(
597597
class DictReader(DataReader):
598598
type: ClassVar[str] = "dict"
599599
data_type: DictType
600-
item_readers: Dict[str, DataReader]
600+
item_readers: Dict[Union[StrictStr, StrictInt], DataReader]
601601

602602
def read(self, artifacts: Artifacts) -> DataType:
603603
artifacts = flatdict.FlatterDict(artifacts, delimiter="/")
604604
data_dict = {}
605605
for (key, dtype_reader) in self.item_readers.items():
606-
v_data_type = dtype_reader.read(artifacts[key]) # type: ignore
606+
v_data_type = dtype_reader.read(artifacts[str(key)]) # type: ignore
607607
data_dict[key] = v_data_type.data
608608
return self.data_type.copy().bind(data_dict)
609609

tests/core/test_data_type.py

Lines changed: 100 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,59 @@
2525
from tests.conftest import data_write_read_check
2626

2727

28+
@pytest.fixture
29+
def dict_data_str_keys():
30+
data = {"1": 1.5, "2": "a", "3": {"1": False}}
31+
payload = {
32+
"item_types": {
33+
"1": {"ptype": "float", "type": "primitive"},
34+
"2": {"ptype": "str", "type": "primitive"},
35+
"3": {
36+
"item_types": {"1": {"ptype": "bool", "type": "primitive"}},
37+
"type": "dict",
38+
},
39+
},
40+
"type": "dict",
41+
}
42+
schema = {
43+
"title": "DictType",
44+
"type": "object",
45+
"properties": {
46+
"1": {"title": "1", "type": "number"},
47+
"2": {"title": "2", "type": "string"},
48+
"3": {"$ref": "#/definitions/3_DictType"},
49+
},
50+
"required": ["1", "2", "3"],
51+
"definitions": {
52+
"3_DictType": {
53+
"title": "3_DictType",
54+
"type": "object",
55+
"properties": {"1": {"title": "1", "type": "boolean"}},
56+
"required": ["1"],
57+
}
58+
},
59+
}
60+
61+
return data, payload, schema
62+
63+
64+
@pytest.fixture
65+
def dict_data_int_keys(dict_data_str_keys):
66+
data = {"1": 1.5, 2: "a", "3": {1: False}}
67+
payload = {
68+
"item_types": {
69+
"1": {"ptype": "float", "type": "primitive"},
70+
2: {"ptype": "str", "type": "primitive"},
71+
"3": {
72+
"item_types": {1: {"ptype": "bool", "type": "primitive"}},
73+
"type": "dict",
74+
},
75+
},
76+
"type": "dict",
77+
}
78+
return data, payload, dict_data_str_keys[2]
79+
80+
2881
class NotPrimitive:
2982
pass
3083

@@ -300,7 +353,6 @@ def dynamic_dict_str_val_type_data():
300353
"type": "object",
301354
"additionalProperties": {"type": "string"},
302355
}
303-
304356
test_data1 = {"a": "1", "b": "2"}
305357
test_data2 = {"a": "1"}
306358
test_data3 = {"a": "1", "b": "2", "c": "3", "d": "1"}
@@ -551,3 +603,50 @@ def custom_assert(x, y):
551603
else:
552604
assert list(artifacts.keys()) == ["data"]
553605
assert artifacts["data"].uri.endswith("data")
606+
607+
608+
@pytest.mark.parametrize(
609+
"dict_data",
610+
[lazy_fixture("dict_data_str_keys"), lazy_fixture("dict_data_int_keys")],
611+
)
612+
def test_dict_key_int_and_str_types(dict_data):
613+
d_value, payload, schema = dict_data
614+
data_type = DataType.create(d_value)
615+
616+
assert isinstance(data_type, DictType)
617+
618+
assert data_type.dict() == payload
619+
dt2 = parse_obj_as(DictType, payload)
620+
assert dt2 == data_type
621+
assert d_value == data_type.serialize(d_value)
622+
assert d_value == data_type.deserialize(d_value)
623+
assert data_type.get_model().__name__ == "DictType"
624+
assert data_type.get_model().schema() == schema
625+
626+
627+
@pytest.mark.parametrize(
628+
"d_value",
629+
[
630+
{"1": 1.5, "2": "a", "3": {"1": False}},
631+
{"1": 1.5, 2: "a", "3": {1: False}},
632+
],
633+
)
634+
def test_dict_source_int_and_str_types(d_value):
635+
data_type = DataType.create(d_value)
636+
637+
def custom_assert(x, y):
638+
assert x == y
639+
assert len(x) == len(y)
640+
assert isinstance(x, dict)
641+
assert isinstance(y, dict)
642+
643+
artifacts = data_write_read_check(
644+
data_type,
645+
reader_type=DictReader,
646+
custom_assert=custom_assert,
647+
)
648+
649+
assert list(artifacts.keys()) == ["1/data", "2/data", "3/1/data"]
650+
assert artifacts["1/data"].uri.endswith("data/1")
651+
assert artifacts["2/data"].uri.endswith("data/2")
652+
assert artifacts["3/1/data"].uri.endswith("data/3/1")

0 commit comments

Comments
 (0)