import logging
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Awaitable, Callable, List, Optional, Protocol, Union

from redis.maint_notifications import (
    MaintenanceNotification,
    NodeFailedOverNotification,
    NodeFailingOverNotification,
    NodeMigratedNotification,
    NodeMigratingNotification,
    NodeMovingNotification,
    OSSNodeMigratedNotification,
    OSSNodeMigratingNotification,
)
from redis.utils import safe_str

# Compare the whole version tuple. Checking ``major`` and ``minor``
# independently would wrongly pick the fallback branch on any future
# X.0 - X.10 interpreter, because ``minor`` restarts at 0.
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout

from ..exceptions import (
    AskError,
    AuthenticationError,
    AuthenticationWrongNumberOfArgsError,
    BusyLoadingError,
    ClusterCrossSlotError,
    ClusterDownError,
    ConnectionError,
    ExecAbortError,
    ExternalAuthProviderError,
    MasterDownError,
    ModuleError,
    MovedError,
    NoPermissionError,
    NoScriptError,
    OutOfMemoryError,
    ReadOnlyError,
    ResponseError,
    TryAgainError,
)
from ..typing import EncodableT
from .encoders import Encoder
from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer

# Exact "ERR" message texts (code prefix stripped) that map to ModuleError.
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module "
    "exports one or more module-side data "
    "types, can't unload"
)
# Errors returned when a client sends AUTH to a server that has no
# password configured.
NO_AUTH_SET_ERROR = {
    # Redis >= 6.0
    "AUTH called without any password "
    "configured for the default user. Are you sure "
    "your configuration is correct?": AuthenticationError,
    # Redis < 6.0
    "Client sent AUTH, but no password is set": AuthenticationError,
}

# "ERR" message texts that map to ExternalAuthProviderError.
EXTERNAL_AUTH_PROVIDER_ERROR = {
    "problem with LDAP service": ExternalAuthProviderError,
}

logger = logging.getLogger(__name__)


class BaseParser(ABC):
    """Common base for reply parsers: maps server error replies to exceptions."""

    # Keyed by the first space-delimited token of an error reply (the error
    # code). A value is either an exception class, or, for "ERR", a dict
    # keyed by the exact remaining message text.
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # some Redis server versions report invalid command syntax
            # in lowercase
            "wrong number of arguments "
            "for 'auth' command": AuthenticationWrongNumberOfArgsError,
            # some Redis server versions report invalid command syntax
            # in uppercase
            "wrong number of arguments "
            "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
            **EXTERNAL_AUTH_PROVIDER_ERROR,
        },
        "OOM": OutOfMemoryError,
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
        "ASK": AskError,
        "TRYAGAIN": TryAgainError,
        "MOVED": MovedError,
        "CLUSTERDOWN": ClusterDownError,
        "CROSSSLOT": ClusterCrossSlotError,
        "MASTERDOWN": MasterDownError,
    }

    @classmethod
    def parse_error(cls, response: str):
        """Build (but do not raise) the exception for an error reply.

        The first space-delimited token of ``response`` is treated as the
        error code.

        * If the code is known, the code and the following space are
          stripped. The mapped class is instantiated with the remaining
          message and ``status_code=<code>``.
        * For "ERR", the class is chosen by exact message text, falling
          back to ``ResponseError``.
        * An unknown code yields ``ResponseError`` with the full, unmodified
          reply text.

        Returns:
            An exception instance; the caller decides whether to raise it.
        """
        error_code, _, message = response.partition(" ")
        if error_code in cls.EXCEPTION_CLASSES:
            exception_class = cls.EXCEPTION_CLASSES[error_code]
            if isinstance(exception_class, dict):
                # "ERR" replies are disambiguated by their exact message text.
                exception_class = exception_class.get(message, ResponseError)
            return exception_class(message, status_code=error_code)
        return ResponseError(response)

    def on_disconnect(self):
        """Hook called when the connection is torn down; subclasses override."""
        raise NotImplementedError()

    def on_connect(self, connection):
        """Hook called once ``connection`` is established; subclasses override."""
        raise NotImplementedError()


class _RESPBase(BaseParser):
    """Base class for sync-based resp parsing"""

    def __init__(self, socket_read_size):
        """Remember the read size; socket state is attached in ``on_connect``."""
        self.socket_read_size = socket_read_size
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        """Release the socket buffer when the parser is garbage-collected."""
        try:
            self.on_disconnect()
        except Exception:
            # Deliberate best-effort cleanup: a finalizer must not raise.
            pass

    def on_connect(self, connection):
        "Called when the socket connects"
        self._sock = connection._sock
        # Wrap the raw socket in a buffer that reads in
        # ``socket_read_size`` chunks and honours the connection's timeout.
        self._buffer = SocketBuffer(
            self._sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        "Called when the socket disconnects"
        self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Return whether a reply can be read within ``timeout``.

        Always returns ``False`` while disconnected (no buffer). Previously
        ``None`` was returned in that case.
        """
        return self._buffer is not None and self._buffer.can_read(timeout)


class AsyncBaseParser(BaseParser):
    """Base parsing class for the python-backed async parser"""

    __slots__ = "_stream", "_read_size"

    def __init__(self, socket_read_size: int):
        """Start without a stream; ``socket_read_size`` is kept as the read size."""
        self._stream: Optional[StreamReader] = None
        self._read_size = socket_read_size

    async def can_read_destructive(self) -> bool:
        """Subclass hook: report whether data is available (may consume it)."""
        raise NotImplementedError()

    async def read_response(
        self, disable_decoding: bool = False
    ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
        """Subclass hook: read and return one reply from the stream."""
        raise NotImplementedError()


class MaintenanceNotificationsParser:
    """Protocol defining maintenance push notification parsing functionality"""

    @staticmethod
    def parse_oss_maintenance_start_msg(response):
        # Expected message format is:
        # SMIGRATING
        id = response[1]
        slots = safe_str(response[2])
        return OSSNodeMigratingNotification(id, slots)

    @staticmethod
    def parse_oss_maintenance_completed_msg(response):
        # Expected message format is:
        # SMIGRATED [[ ], ...]
        id = response[1]
        nodes_to_slots_mapping_data = response[2]
        # Build the nodes_to_slots_mapping dict structure:
        # {
        #     "src_host:port": [
        #         {"dest_host:port": "slot_range"},
        #         ...
        #     ],
        #     ...
# } nodes_to_slots_mapping = {} for src_node, dest_node, slots in nodes_to_slots_mapping_data: src_node_str = safe_str(src_node) dest_node_str = safe_str(dest_node) slots_str = safe_str(slots) if src_node_str not in nodes_to_slots_mapping: nodes_to_slots_mapping[src_node_str] = [] nodes_to_slots_mapping[src_node_str].append({dest_node_str: slots_str}) return OSSNodeMigratedNotification(id, nodes_to_slots_mapping) @staticmethod def parse_maintenance_start_msg(response, notification_type): # Expected message format is: