Skip to content

Commit 02b1a66

Browse files
committed
Redis(Cache) now (re)supports function as a key_prefix
1 parent 620e215 commit 02b1a66

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

Diff for: src/cachelib/redis.py

+20-16
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def __init__(
5757
self._read_client = self._write_client = host
5858
self.key_prefix = key_prefix or ""
5959

60+
def _get_prefix(self):
61+
return (
62+
self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix()
63+
)
64+
6065
def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
6166
"""Normalize timeout by setting it to default of 300 if
6267
not defined (None) or -1 if explicitly set to zero.
@@ -69,11 +74,11 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
6974
return timeout
7075

7176
def get(self, key: str) -> _t.Any:
72-
return self.serializer.loads(self._read_client.get(self.key_prefix + key))
77+
return self.serializer.loads(self._read_client.get(self._get_prefix() + key))
7378

7479
def get_many(self, *keys: str) -> _t.List[_t.Any]:
7580
if self.key_prefix:
76-
prefixed_keys = [self.key_prefix + key for key in keys]
81+
prefixed_keys = [self._get_prefix() + key for key in keys]
7782
else:
7883
prefixed_keys = list(keys)
7984
return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)]
@@ -82,20 +87,20 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.A
8287
timeout = self._normalize_timeout(timeout)
8388
dump = self.serializer.dumps(value)
8489
if timeout == -1:
85-
result = self._write_client.set(name=self.key_prefix + key, value=dump)
90+
result = self._write_client.set(name=self._get_prefix() + key, value=dump)
8691
else:
8792
result = self._write_client.setex(
88-
name=self.key_prefix + key, value=dump, time=timeout
93+
name=self._get_prefix() + key, value=dump, time=timeout
8994
)
9095
return result
9196

9297
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
9398
timeout = self._normalize_timeout(timeout)
9499
dump = self.serializer.dumps(value)
95-
created = self._write_client.setnx(name=self.key_prefix + key, value=dump)
100+
created = self._write_client.setnx(name=self._get_prefix() + key, value=dump)
96101
# handle case where timeout is explicitly set to zero
97102
if created and timeout != -1:
98-
self._write_client.expire(name=self.key_prefix + key, time=timeout)
103+
self._write_client.expire(name=self._get_prefix() + key, time=timeout)
99104
return created
100105

101106
def set_many(
@@ -109,41 +114,40 @@ def set_many(
109114
for key, value in mapping.items():
110115
dump = self.serializer.dumps(value)
111116
if timeout == -1:
112-
pipe.set(name=self.key_prefix + key, value=dump)
117+
pipe.set(name=self._get_prefix() + key, value=dump)
113118
else:
114-
pipe.setex(name=self.key_prefix + key, value=dump, time=timeout)
119+
pipe.setex(name=self._get_prefix() + key, value=dump, time=timeout)
115120
results = pipe.execute()
116-
res = zip(mapping.keys(), results) # noqa: B905
117-
return [k for k, was_set in res if was_set]
121+
return [k for k, was_set in zip(mapping.keys(), results) if was_set]
118122

119123
def delete(self, key: str) -> bool:
120-
return bool(self._write_client.delete(self.key_prefix + key))
124+
return bool(self._write_client.delete(self._get_prefix() + key))
121125

122126
def delete_many(self, *keys: str) -> _t.List[_t.Any]:
123127
if not keys:
124128
return []
125129
if self.key_prefix:
126-
prefixed_keys = [self.key_prefix + key for key in keys]
130+
prefixed_keys = [self._get_prefix() + key for key in keys]
127131
else:
128132
prefixed_keys = [k for k in keys]
129133
self._write_client.delete(*prefixed_keys)
130134
return [k for k in prefixed_keys if not self.has(k)]
131135

132136
def has(self, key: str) -> bool:
133-
return bool(self._read_client.exists(self.key_prefix + key))
137+
return bool(self._read_client.exists(self._get_prefix() + key))
134138

135139
def clear(self) -> bool:
136140
status = 0
137141
if self.key_prefix:
138-
keys = self._read_client.keys(self.key_prefix + "*")
142+
keys = self._read_client.keys(self._get_prefix() + "*")
139143
if keys:
140144
status = self._write_client.delete(*keys)
141145
else:
142146
status = self._write_client.flushdb()
143147
return bool(status)
144148

145149
def inc(self, key: str, delta: int = 1) -> _t.Any:
146-
return self._write_client.incr(name=self.key_prefix + key, amount=delta)
150+
return self._write_client.incr(name=self._get_prefix() + key, amount=delta)
147151

148152
def dec(self, key: str, delta: int = 1) -> _t.Any:
149-
return self._write_client.incr(name=self.key_prefix + key, amount=-delta)
153+
return self._write_client.incr(name=self._get_prefix() + key, amount=-delta)

0 commit comments

Comments
 (0)