You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

271 lines
10 KiB

from logging import getLogger
from typing import Any, Union
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import (
AsyncPushNotificationsParser,
PushNotificationsParser,
_AsyncRESPBase,
_RESPBase,
)
from .socket import SERVER_CLOSED_CONNECTION_ERROR
class _RESP3Parser(_RESPBase, PushNotificationsParser):
"""RESP3 protocol implementation"""
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.oss_cluster_maint_push_handler_func = None
self.invalidation_push_handler_func = None
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def read_response(self, disable_decoding=False, push_request=False):
pos = self._buffer.get_pos() if self._buffer is not None else None
try:
result = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
except BaseException:
if self._buffer is not None:
self._buffer.rewind(pos)
raise
else:
if self._buffer is not None:
try:
self._buffer.purge()
except AttributeError:
# Buffer may have been set to None by another thread after
# the check above; result is still valid so we don't raise
pass
return result
def _read_response(self, disable_decoding=False, push_request=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
byte, response = raw[:1], raw[1:]
# server returned an error
if byte in (b"-", b"!"):
if byte == b"!":
response = self._buffer.read(int(response))
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# null value
elif byte == b"_":
return None
# int and big int values
elif byte in (b":", b"("):
return int(response)
# double value
elif byte == b",":
return float(response)
# bool value
elif byte == b"#":
return response == b"t"
# bulk response
elif byte == b"$":
response = self._buffer.read(int(response))
# verbatim string response
elif byte == b"=":
response = self._buffer.read(int(response))[4:]
# array response
elif byte == b"*":
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we return sets as list, all the time, for predictability
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
resp_dict = {}
for _ in range(int(response)):
key = self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
response = resp_dict
# push response
elif byte == b">":
response = [
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
]
response = self.handle_push_response(response)
# if this is a push request return the push response
if push_request:
return response
return self._read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
):
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response
async def _read_response(
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
# if byte not in (b"-", b"+", b":", b"$", b"*"):
# raise InvalidResponse(f"Protocol Error: {raw!r}")
# server returned an error
if byte in (b"-", b"!"):
if byte == b"!":
response = await self._read(int(response))
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# null value
elif byte == b"_":
return None
# int and big int values
elif byte in (b":", b"("):
return int(response)
# double value
elif byte == b",":
return float(response)
# bool value
elif byte == b"#":
return response == b"t"
# bulk response
elif byte == b"$":
response = await self._read(int(response))
# verbatim string response
elif byte == b"=":
response = (await self._read(int(response)))[4:]
# array response
elif byte == b"*":
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we always convert to a list, to have predictable return types
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
resp_dict = {}
for _ in range(int(response)):
key = await self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
response = resp_dict
# push response
elif byte == b">":
response = [
(
await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
)
for _ in range(int(response))
]
response = await self.handle_push_response(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response