diff --git a/pyproject.toml b/pyproject.toml index c306cbe..9c0711a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,8 @@ classifiers = [ dependencies = [ "numpy", "orjson", - "tqdm" + "tqdm", + "msgspec" ] [build-system] diff --git a/uv.lock b/uv.lock index 97c156f..35c7d9a 100644 --- a/uv.lock +++ b/uv.lock @@ -299,7 +299,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -1748,7 +1748,7 @@ name = "tqdm" version = "4.67.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e8/4f/0153c21dc5779a49a0598c445b1978126b1344bab9ee71e53e44877e14e0/tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a", size = 169739 } wheels = [ @@ -1841,7 +1841,7 @@ wheels = [ [[package]] name = "vicinity" -version = "0.4.0" +version = "0.4.1" source = { editable = "." 
} dependencies = [ { name = "numpy" }, diff --git a/vicinity/datatypes.py b/vicinity/datatypes.py index dcad8f9..63b131b 100644 --- a/vicinity/datatypes.py +++ b/vicinity/datatypes.py @@ -22,3 +22,8 @@ class Backend(str, Enum): FAISS = "faiss" USEARCH = "usearch" VOYAGER = "voyager" + + +class ItemType(str, Enum): + JSON = "json" + MSGPACK = "msgpack" diff --git a/vicinity/items.py b/vicinity/items.py new file mode 100644 index 0000000..3b684c1 --- /dev/null +++ b/vicinity/items.py @@ -0,0 +1,28 @@ +from collections.abc import Sequence + +from msgspec.json import decode as json_decode +from msgspec.json import encode as json_encode +from msgspec.msgpack import decode as msgpack_decode +from msgspec.msgpack import encode as msgpack_encode + +from vicinity.datatypes import ItemType + + +def encode_msgspec(data: Sequence[str], itemtype: ItemType) -> bytes: + """Encode a list of strings to a format supported by msgspec.""" + if itemtype == ItemType.JSON: + return json_encode(data) + elif itemtype == ItemType.MSGPACK: + return msgpack_encode(data) + else: + raise ValueError(f"Unknown item type: {itemtype}") + + +def decode_msgpack(data: bytes, itemtype: ItemType) -> Sequence[str]: + """Decode bytes to a list of strings based on the item type.""" + if itemtype == ItemType.JSON: + return json_decode(data) + elif itemtype == ItemType.MSGPACK: + return msgpack_decode(data) + else: + raise ValueError(f"Unknown item type: {itemtype}") diff --git a/vicinity/vicinity.py b/vicinity/vicinity.py index e9b7ff0..e2ead3d 100644 --- a/vicinity/vicinity.py +++ b/vicinity/vicinity.py @@ -15,7 +15,8 @@ from vicinity import Metric from vicinity.backends import AbstractBackend, BasicBackend, BasicVectorStore, get_backend_class -from vicinity.datatypes import Backend, PathLike, QueryResult +from vicinity.datatypes import Backend, ItemType, PathLike, QueryResult +from vicinity.items import decode_msgpack, encode_msgspec logger = logging.getLogger(__name__) @@ -165,9 +166,7 @@ def 
query_threshold( return out def save( - self, - folder: PathLike, - overwrite: bool = False, + self, folder: PathLike, overwrite: bool = False, serialize_type: ItemType | str = ItemType.MSGPACK ) -> None: """ Save a Vicinity instance in a fast format. @@ -177,21 +176,36 @@ def save( :param folder: The path to which to save the JSON file. The vectors are saved separately. The JSON contains a path to the numpy file. :param overwrite: Whether to overwrite the JSON and numpy files if they already exist. + :param serialize_type: The serialization type to use for the items. Defaults to MSGPACK. :raises ValueError: If the path is not a directory. :raises JSONEncodeError: If the items are not serializable. """ path = Path(folder) path.mkdir(parents=True, exist_ok=overwrite) + serialize_type = ItemType(serialize_type) + if not path.is_dir(): raise ValueError(f"Path {path} should be a directory.") - items_dict = {"items": self.items, "metadata": self.metadata, "backend_type": self.backend.backend_type.value} + items_dict = { + "metadata": self.metadata, + "backend_type": self.backend.backend_type.value, + "items_type": serialize_type.value, + } try: with open(path / "data.json", "wb") as file_handle: file_handle.write(orjson.dumps(items_dict)) except JSONEncodeError as e: - raise JSONEncodeError(f"Items could not be encoded to JSON because they are not serializable: {e}") + raise JSONEncodeError(f"Metadata could not be encoded to JSON because they are not serializable: {e}") + try: + item_bytes = encode_msgspec(self.items, serialize_type) + except ValueError as e: + raise ValueError( + f"Items could not be encoded to {serialize_type.value} because they are not serializable: {e}" + ) + with open(path / f"items.{serialize_type.value}", "wb") as file_handle: + file_handle.write(item_bytes) self.backend.save(path) if self.vector_store is not None: @@ -215,7 +229,15 @@ def load(cls, filename: PathLike) -> Vicinity: with open(folder_path / "data.json", "rb") as file_handle: data: 
dict[str, Any] = orjson.loads(file_handle.read()) - items: Sequence[Any] = data["items"] + + if "items" in data: + logger.warning("The 'items' key is deprecated. Please save this vicinity instance and load it again.") + items: Sequence[Any] = data["items"] + else: + item_type = ItemType(data["items_type"]) + items = [] + with open(folder_path / f"items.{item_type.value}", "rb") as file_handle: + items = decode_msgpack(file_handle.read(), item_type) metadata: dict[str, Any] = data["metadata"] backend_type = Backend(data["backend_type"])