1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
|
from __future__ import annotations
from datetime import timedelta
from typing import TYPE_CHECKING, cast
from redis.asyncio import Redis
from redis.asyncio.connection import ConnectionPool
from litestar.exceptions import ImproperlyConfiguredException
from litestar.types import Empty, EmptyType
from litestar.utils.empty import value_or_default
from .base import NamespacedStore
__all__ = ("RedisStore",)

if TYPE_CHECKING:
    # Only referenced in annotations (which are lazy thanks to
    # ``from __future__ import annotations``), so no runtime import is needed.
    from types import TracebackType
class RedisStore(NamespacedStore):
    """Redis based, thread and process safe asynchronous key/value store."""

    __slots__ = ("_redis",)

    def __init__(
        self, redis: Redis, namespace: str | None | EmptyType = Empty, handle_client_shutdown: bool = False
    ) -> None:
        """Initialize :class:`RedisStore`

        Args:
            redis: An :class:`redis.asyncio.Redis` instance
            namespace: A key prefix to simulate a namespace in redis. If not given,
                defaults to ``LITESTAR``. Namespacing can be explicitly disabled by passing
                ``None``. This will make :meth:`.delete_all` unavailable.
            handle_client_shutdown: If ``True``, handle the shutdown of the `redis` instance automatically
                during the store's lifespan. Should be set to `True` unless the shutdown is handled externally.
        """
        self._redis = redis
        self.namespace: str | None = value_or_default(namespace, "LITESTAR")
        self.handle_client_shutdown = handle_client_shutdown

        # script to get and renew a key in one atomic step. The expiry is only
        # refreshed when the key already has one (TTL > 0); keys stored without an
        # expiry are left untouched.
        self._get_and_renew_script = self._redis.register_script(
            b"""
        local key = KEYS[1]
        local renew = tonumber(ARGV[1])

        local data = redis.call('GET', key)
        local ttl = redis.call('TTL', key)

        if ttl > 0 then
            redis.call('EXPIRE', key, renew)
        end

        return data
        """
        )

        # script to delete every key matching any of the glob patterns passed in
        # ARGV, executed as a single atomic script invocation
        self._delete_all_script = self._redis.register_script(
            b"""
        for _, pattern in ipairs(ARGV) do
            local cursor = 0
            repeat
                local result = redis.call('SCAN', cursor, 'MATCH', pattern)
                for _, key in ipairs(result[2]) do
                    redis.call('UNLINK', key)
                end
                cursor = tonumber(result[1])
            until cursor == 0
        end
        """
        )

    async def _shutdown(self) -> None:
        """Close the Redis client and its connection pool, if this store handles client shutdown."""
        if self.handle_client_shutdown:
            await self._redis.aclose(close_connection_pool=True)  # type: ignore[attr-defined]

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Run :meth:`_shutdown` when leaving the store's async context."""
        await self._shutdown()

    @classmethod
    def with_client(
        cls,
        url: str = "redis://localhost:6379",
        *,
        db: int | None = None,
        port: int | None = None,
        username: str | None = None,
        password: str | None = None,
        namespace: str | None | EmptyType = Empty,
    ) -> RedisStore:
        """Initialize a :class:`RedisStore` instance with a new :class:`redis.asyncio.Redis` instance.

        The created client is owned by the store, so ``handle_client_shutdown`` is
        set to ``True``.

        Args:
            url: Redis URL to connect to
            db: Redis database to use
            port: Redis port to use
            username: Redis username to use
            password: Redis password to use
            namespace: Virtual key namespace to use
        """
        pool = ConnectionPool.from_url(
            url=url,
            db=db,
            decode_responses=False,
            port=port,
            username=username,
            password=password,
        )
        return cls(
            redis=Redis(connection_pool=pool),
            namespace=namespace,
            handle_client_shutdown=True,
        )

    def with_namespace(self, namespace: str) -> RedisStore:
        """Return a new :class:`RedisStore` with a nested virtual key namespace.

        The current instances namespace will serve as a prefix for the namespace, so it
        can be considered the parent namespace. The new store shares this store's Redis
        client and inherits its ``handle_client_shutdown`` setting.
        """
        return type(self)(
            redis=self._redis,
            namespace=f"{self.namespace}_{namespace}" if self.namespace else namespace,
            handle_client_shutdown=self.handle_client_shutdown,
        )

    def _make_key(self, key: str) -> str:
        """Prefix ``key`` with ``"<namespace>:"``, or return it unchanged if no namespace is set."""
        prefix = f"{self.namespace}:" if self.namespace else ""
        return prefix + key

    @staticmethod
    def _escape_glob_pattern(value: str) -> str:
        """Backslash-escape Redis glob metacharacters in ``value`` so it matches literally."""
        return "".join(f"\\{char}" if char in "*?[]\\" else char for char in value)

    async def set(self, key: str, value: str | bytes, expires_in: int | timedelta | None = None) -> None:
        """Set a value.

        Args:
            key: Key to associate the value with
            value: Value to store; ``str`` values are encoded as UTF-8
            expires_in: Time in seconds before the key is considered expired

        Returns:
            ``None``
        """
        if isinstance(value, str):
            value = value.encode("utf-8")
        await self._redis.set(self._make_key(key), value, ex=expires_in)

    async def get(self, key: str, renew_for: int | timedelta | None = None) -> bytes | None:
        """Get a value.

        Args:
            key: Key associated with the value
            renew_for: If given and the value had an initial expiry time set, renew the
                expiry time for ``renew_for`` seconds. If the value has not been set
                with an expiry time this is a no-op. Atomicity of this step is guaranteed
                by using a lua script to execute fetch and renewal. If ``renew_for`` is
                not given, the script will be bypassed so no overhead will occur

        Returns:
            The value associated with ``key`` if it exists and is not expired, else
            ``None``
        """
        key = self._make_key(key)
        if renew_for:
            if isinstance(renew_for, timedelta):
                # ``timedelta.seconds`` is only the seconds *component* (0-86399) and
                # drops whole days: ``timedelta(days=1).seconds == 0``, and ``EXPIRE``
                # with a non-positive timeout deletes the key. Use the full duration.
                renew_for = int(renew_for.total_seconds())
            data = await self._get_and_renew_script(keys=[key], args=[renew_for])
            return cast("bytes | None", data)
        return await self._redis.get(key)

    async def delete(self, key: str) -> None:
        """Delete a value.

        If no such key exists, this is a no-op.

        Args:
            key: Key of the value to delete
        """
        await self._redis.delete(self._make_key(key))

    async def delete_all(self) -> None:
        """Delete all stored values in the virtual key namespace.

        This includes values of nested namespaces created via :meth:`with_namespace`.

        Raises:
            ImproperlyConfiguredException: If no namespace was configured
        """
        if not self.namespace:
            raise ImproperlyConfiguredException("Cannot perform delete operation: No namespace configured")

        # Escape glob metacharacters so the namespace is matched literally. Then match
        # only this namespace's own keys (``<ns>:<key>``) and those of nested
        # namespaces (``<ns>_<child>:<key>``). A bare ``<ns>*:*`` pattern would also
        # hit unrelated namespaces that merely share the prefix (``session`` vs.
        # ``sessions``).
        namespace = self._escape_glob_pattern(self.namespace)
        await self._delete_all_script(keys=[], args=[f"{namespace}:*", f"{namespace}_*:*"])

    async def exists(self, key: str) -> bool:
        """Check if a given ``key`` exists."""
        return await self._redis.exists(self._make_key(key)) == 1

    async def expires_in(self, key: str) -> int | None:
        """Get the time in seconds ``key`` expires in.

        Returns ``None`` if no such ``key`` exists, and ``-1`` if the key exists but no
        expiry time was set.
        """
        ttl = await self._redis.ttl(self._make_key(key))
        # Redis TTL returns -2 for a missing key and -1 for a key without an expiry;
        # only the former is mapped to ``None``, -1 is passed through unchanged.
        return None if ttl == -2 else ttl
|