Source code for litegram.fsm.storage.redis

from __future__ import annotations

import json
from collections.abc import AsyncGenerator, Callable, Mapping
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, cast

from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool
from redis.asyncio.lock import Lock

from litegram.exceptions import DataNotDictLikeError
from litegram.fsm.state import State
from litegram.fsm.storage.base import (
    BaseEventIsolation,
    BaseStorage,
    DefaultKeyBuilder,
    KeyBuilder,
    StateType,
    StorageKey,
)

if TYPE_CHECKING:
    from redis.typing import ExpiryT

DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
_JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]


[docs] class RedisStorage(BaseStorage): """ Redis storage required :code:`redis` package installed (:code:`pip install redis`) """
[docs] def __init__( self, redis: Redis, key_builder: KeyBuilder | None = None, state_ttl: ExpiryT | None = None, data_ttl: ExpiryT | None = None, json_loads: _JsonLoads = json.loads, json_dumps: _JsonDumps = json.dumps, ) -> None: """ :param redis: Instance of Redis connection :param key_builder: builder that helps to convert contextual key to string :param state_ttl: TTL for state records :param data_ttl: TTL for data records """ if key_builder is None: key_builder = DefaultKeyBuilder() self.redis = redis self.key_builder = key_builder self.state_ttl = state_ttl self.data_ttl = data_ttl self.json_loads = json_loads self.json_dumps = json_dumps
[docs] @classmethod def from_url( cls, url: str, connection_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> RedisStorage: """ Create an instance of :class:`RedisStorage` with specifying the connection string :param url: for example :code:`redis://user:password@host:port/db` :param connection_kwargs: see :code:`redis` docs :param kwargs: arguments to be passed to :class:`RedisStorage` :return: an instance of :class:`RedisStorage` """ if connection_kwargs is None: connection_kwargs = {} pool = ConnectionPool.from_url(url, **connection_kwargs) redis = Redis(connection_pool=pool) return cls(redis=redis, **kwargs)
def create_isolation(self, **kwargs: Any) -> RedisEventIsolation: return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs) async def close(self) -> None: await self.redis.aclose(close_connection_pool=True) async def set_state( self, key: StorageKey, state: StateType = None, ) -> None: redis_key = self.key_builder.build(key, "state") if state is None: await self.redis.delete(redis_key) else: await self.redis.set( redis_key, cast("str", state.state if isinstance(state, State) else state), ex=self.state_ttl, ) async def get_state( self, key: StorageKey, ) -> str | None: redis_key = self.key_builder.build(key, "state") value = await self.redis.get(redis_key) if isinstance(value, bytes): return value.decode("utf-8") return cast("str | None", value) async def set_data( self, key: StorageKey, data: Mapping[str, Any], ) -> None: if not isinstance(data, dict): msg = f"Data must be a dict or dict-like object, got {type(data).__name__}" raise DataNotDictLikeError(msg) redis_key = self.key_builder.build(key, "data") if not data: await self.redis.delete(redis_key) return await self.redis.set( redis_key, self.json_dumps(data), ex=self.data_ttl, ) async def get_data( self, key: StorageKey, ) -> dict[str, Any]: redis_key = self.key_builder.build(key, "data") value = await self.redis.get(redis_key) if value is None: return {} if isinstance(value, bytes): value = value.decode("utf-8") return cast("dict[str, Any]", self.json_loads(value))
class RedisEventIsolation(BaseEventIsolation): def __init__( self, redis: Redis, key_builder: KeyBuilder | None = None, lock_kwargs: dict[str, Any] | None = None, ) -> None: if key_builder is None: key_builder = DefaultKeyBuilder() if lock_kwargs is None: lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS self.redis = redis self.key_builder = key_builder self.lock_kwargs = lock_kwargs @classmethod def from_url( cls, url: str, connection_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> RedisEventIsolation: if connection_kwargs is None: connection_kwargs = {} pool = ConnectionPool.from_url(url, **connection_kwargs) redis = Redis(connection_pool=pool) return cls(redis=redis, **kwargs) @asynccontextmanager async def lock( self, key: StorageKey, ) -> AsyncGenerator[None]: redis_key = self.key_builder.build(key, "lock") async with self.redis.lock(name=redis_key, **self.lock_kwargs, lock_class=Lock): yield None async def close(self) -> None: pass