Source code for litegram.fsm.storage.memory
from __future__ import annotations
from asyncio import Lock
from collections import defaultdict
from contextlib import asynccontextmanager
from copy import copy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, overload
from litegram.exceptions import DataNotDictLikeError
from litegram.fsm.state import State
from litegram.fsm.storage.base import (
BaseEventIsolation,
BaseStorage,
StateType,
StorageKey,
)
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Hashable, Mapping
@dataclass
class MemoryStorageRecord:
data: dict[str, Any] = field(default_factory=dict)
state: str | None = None
[docs]
class MemoryStorage(BaseStorage):
"""
Default FSM storage, stores all data in :class:`dict` and loss everything on shutdown
.. warning::
Is not recommended using in production in due to you will lose all data
when your bot restarts
"""
[docs]
def __init__(self) -> None:
self.storage: defaultdict[StorageKey, MemoryStorageRecord] = defaultdict(
MemoryStorageRecord,
)
async def close(self) -> None:
pass
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
self.storage[key].state = state.state if isinstance(state, State) else state
async def get_state(self, key: StorageKey) -> str | None:
return self.storage[key].state
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
if not isinstance(data, dict):
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
raise DataNotDictLikeError(msg)
self.storage[key].data = data.copy()
async def get_data(self, key: StorageKey) -> dict[str, Any]:
return self.storage[key].data.copy()
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Any | None: ...
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: ...
async def get_value(
self,
storage_key: StorageKey,
dict_key: str,
default: Any | None = None,
) -> Any | None:
data = self.storage[storage_key].data
return copy(data.get(dict_key, default))
class DisabledEventIsolation(BaseEventIsolation):
@asynccontextmanager
async def lock(self, key: StorageKey) -> AsyncGenerator[None]:
yield
async def close(self) -> None:
pass
class SimpleEventIsolation(BaseEventIsolation):
def __init__(self) -> None:
# TODO: Unused locks cleaner is needed
self._locks: defaultdict[Hashable, Lock] = defaultdict(Lock)
@asynccontextmanager
async def lock(self, key: StorageKey) -> AsyncGenerator[None]:
lock = self._locks[key]
async with lock:
yield
async def close(self) -> None:
self._locks.clear()