Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ classifiers = [
dependencies = [
"numpy",
"orjson",
"tqdm"
"tqdm",
"msgspec"
]

[build-system]
Expand Down
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions vicinity/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ class Backend(str, Enum):
FAISS = "faiss"
USEARCH = "usearch"
VOYAGER = "voyager"


class ItemType(str, Enum):
JSON = "json"
MSGPACK = "msgpack"
28 changes: 28 additions & 0 deletions vicinity/items.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from collections.abc import Sequence

from msgspec.json import decode as json_decode
from msgspec.json import encode as json_encode
from msgspec.msgpack import decode as msgpack_decode
from msgspec.msgpack import encode as msgpack_encode

from vicinity.datatypes import ItemType


def encode_msgspec(data: Sequence[str], itemtype: ItemType) -> bytes:
"""Encode a list of strings to a format supported by msgpack."""
if itemtype == ItemType.JSON:
return json_encode(data)
elif itemtype == ItemType.MSGPACK:
return msgpack_encode(data)
else:
raise ValueError(f"Unknown item type: {itemtype}")


def decode_msgpack(data: bytes, itemtype: ItemType) -> Sequence[str]:
"""Decode bytes to a list of strings based on the item type."""
if itemtype == ItemType.JSON:
return json_decode(data)
elif itemtype == ItemType.MSGPACK:
return msgpack_decode(data)
else:
raise ValueError(f"Unknown item type: {itemtype}")
36 changes: 29 additions & 7 deletions vicinity/vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from vicinity import Metric
from vicinity.backends import AbstractBackend, BasicBackend, BasicVectorStore, get_backend_class
from vicinity.datatypes import Backend, PathLike, QueryResult
from vicinity.datatypes import Backend, ItemType, PathLike, QueryResult
from vicinity.items import decode_msgpack, encode_msgspec

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -165,9 +166,7 @@ def query_threshold(
return out

def save(
self,
folder: PathLike,
overwrite: bool = False,
self, folder: PathLike, overwrite: bool = False, serialize_type: ItemType | str = ItemType.MSGPACK
) -> None:
"""
Save a Vicinity instance in a fast format.
Expand All @@ -177,21 +176,36 @@ def save(

:param folder: The path to which to save the JSON file. The vectors are saved separately. The JSON contains a path to the numpy file.
:param overwrite: Whether to overwrite the JSON and numpy files if they already exist.
:param serialize_type: The serialization type to use for the items. Defaults to MSGPACK.
:raises ValueError: If the path is not a directory.
:raises JSONEncodeError: If the items are not serializable.
"""
path = Path(folder)
path.mkdir(parents=True, exist_ok=overwrite)

serialize_type = ItemType(serialize_type)

if not path.is_dir():
raise ValueError(f"Path {path} should be a directory.")

items_dict = {"items": self.items, "metadata": self.metadata, "backend_type": self.backend.backend_type.value}
items_dict = {
"metadata": self.metadata,
"backend_type": self.backend.backend_type.value,
"items_type": serialize_type.value,
}
try:
with open(path / "data.json", "wb") as file_handle:
file_handle.write(orjson.dumps(items_dict))
except JSONEncodeError as e:
raise JSONEncodeError(f"Items could not be encoded to JSON because they are not serializable: {e}")
raise JSONEncodeError(f"Metadata could not be encoded to JSON because they are not serializable: {e}")
try:
item_bytes = encode_msgspec(self.items, serialize_type)
except ValueError as e:
raise ValueError(
f"Items could not be encoded to {serialize_type.value} because they are not serializable: {e}"
)
with open(path / f"items.{serialize_type.value}", "wb") as file_handle:
file_handle.write(item_bytes)

self.backend.save(path)
if self.vector_store is not None:
Expand All @@ -215,7 +229,15 @@ def load(cls, filename: PathLike) -> Vicinity:

with open(folder_path / "data.json", "rb") as file_handle:
data: dict[str, Any] = orjson.loads(file_handle.read())
items: Sequence[Any] = data["items"]

if "items" in data:
logger.warning("The 'items' key is deprecated. Please save this vicinity instance and load it again.")
items: Sequence[Any] = data["items"]
else:
item_type = ItemType(data["items_type"])
items = []
with open(folder_path / f"items.{item_type.value}", "rb") as file_handle:
items = decode_msgpack(file_handle.read(), item_type)

metadata: dict[str, Any] = data["metadata"]
backend_type = Backend(data["backend_type"])
Expand Down