You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

389 lines
13 KiB

from enum import Enum
from typing import Any, Dict, List, Optional, Union
from redis.utils import experimental
try:
from typing import Self # Py 3.11+
except ImportError:
from typing_extensions import Self
from redis.commands.search.aggregation import Limit, Reducer
from redis.commands.search.query import Filter, SortbyField
@experimental
class HybridSearchQuery:
    """Text-search half of a hybrid query, rendered as a ``SEARCH`` clause."""

    def __init__(
        self,
        query_string: str,
        scorer: Optional[str] = None,
        yield_score_as: Optional[str] = None,
    ) -> None:
        """
        Build the text-search part of a hybrid query.

        Args:
            query_string: The text query to run.
            scorer: Name of the scoring algorithm for the text query, for
                example "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE",
                "BM25", "BM25STD", "BM25STD.TANH" or "HAMMING". The value is
                stored as given; no validation happens here. See
                https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/
            yield_score_as: Name of the field the score is yielded as.
        """
        self._query_string = query_string
        self._scorer = scorer
        self._yield_score_as = yield_score_as

    def query_string(self) -> str:
        """Return the text query this object was created with."""
        return self._query_string

    def scorer(self, scorer: str) -> "HybridSearchQuery":
        """
        Set the scoring algorithm for the text query and return ``self``.

        Example values: "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE",
        "BM25", "BM25STD", "BM25STD.TANH", "HAMMING". See
        https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/
        """
        self._scorer = scorer
        return self

    def yield_score_as(self, alias: str) -> "HybridSearchQuery":
        """Set the field name the score is yielded as and return ``self``."""
        self._yield_score_as = alias
        return self

    def get_args(self) -> List[str]:
        """
        Render the clause as a flat argument list.

        Returns:
            ``["SEARCH", <query>]`` followed by ``SCORER <name>`` and
            ``YIELD_SCORE_AS <alias>``, each only when its value is truthy.
        """
        rendered = ["SEARCH", self._query_string]
        optional_parts = (
            ("SCORER", self._scorer),
            ("YIELD_SCORE_AS", self._yield_score_as),
        )
        for keyword, value in optional_parts:
            if value:
                rendered += [keyword, value]
        return rendered
class VectorSearchMethods(Enum):
    """Search methods that can be selected for the vector-similarity query."""

    # Each member's value is the literal keyword placed in the command arguments.
    KNN = "KNN"
    RANGE = "RANGE"
@experimental
class HybridVsimQuery:
    """Vector-similarity half of a hybrid query, rendered as a ``VSIM`` clause."""

    def __init__(
        self,
        vector_field_name: str,
        vector_data: Union[bytes, str],
        vsim_search_method: Optional["VectorSearchMethods"] = None,
        vsim_search_method_params: Optional[Dict[str, Any]] = None,
        filter: Optional["Filter"] = None,
        yield_score_as: Optional[str] = None,
    ) -> None:
        """
        Create a new hybrid vsim query object.

        Args:
            vector_field_name: Vector field name.
            vector_data: Vector data for the search.
            vsim_search_method: Search method used for the vsim search.
            vsim_search_method_params: Search method parameters, keyed by
                parameter name.
                Example for KNN: {"K": 10, "EF_RUNTIME": 100}, where K is
                mandatory and defines the number of results and EF_RUNTIME
                is optional and defines the exploration factor.
                Example for RANGE: {"RADIUS": 10, "EPSILON": 0.1}, where
                RADIUS is mandatory and defines the radius of the search
                and EPSILON is optional and defines the accuracy of the
                search.
            filter: If defined, a filter is applied on the vsim query results.
            yield_score_as: The name of the field to yield the score as.

        Note:
            The search method is recorded only when both
            ``vsim_search_method`` and a non-empty
            ``vsim_search_method_params`` are given; otherwise no method
            clause is emitted by ``get_args``.
        """
        self._vector_field = vector_field_name
        self._vector_data = vector_data
        self._vsim_method_params = None
        if vsim_search_method and vsim_search_method_params:
            self.vsim_method_params(vsim_search_method, **vsim_search_method_params)
        self._filter = filter
        self._yield_score_as = yield_score_as

    def vector_field(self) -> str:
        """Return the vector field name of this query object."""
        return self._vector_field

    def vector_data(self) -> Union[bytes, str]:
        """Return the vector data of this query object."""
        return self._vector_data

    def vsim_method_params(
        self,
        method: "VectorSearchMethods",
        **kwargs,
    ) -> "HybridVsimQuery":
        """
        Set the search method and its parameters, replacing any previous ones.

        Args:
            method: Vector search method (KNN or RANGE).
            kwargs: Search method parameters, keyed by parameter name, e.g.
                ``K=10, EF_RUNTIME=100``. Values are passed through as is.

        Returns:
            ``self``, to allow call chaining.
        """
        params: List[Any] = [method.value]
        if kwargs:
            # The count covers every following token: one name plus one value
            # per parameter.
            params.append(len(kwargs) * 2)
            for name, value in kwargs.items():
                params.extend((name, value))
        self._vsim_method_params = params
        return self

    def filter(self, flt: "HybridFilter") -> "HybridVsimQuery":
        """
        Set the filter applied to the vsim query results.

        Args:
            flt: A HybridFilter object; its ``args`` are appended verbatim
                by ``get_args``.

        Returns:
            ``self``, to allow call chaining.
        """
        self._filter = flt
        return self

    def yield_score_as(self, alias: str) -> "HybridVsimQuery":
        """Return the score as a field with name `alias`; returns ``self``."""
        self._yield_score_as = alias
        return self

    def get_args(self) -> List[Union[str, bytes, int, float]]:
        """
        Render the clause as a flat argument list.

        Returns:
            ``["VSIM", <field>, <vector data>]`` followed, when set, by the
            method parameters, the filter arguments and
            ``YIELD_SCORE_AS <alias>``. The list can hold bytes (the vector
            data) and numbers (parameter count and values), not only str.
        """
        args = ["VSIM", self._vector_field, self._vector_data]
        if self._vsim_method_params:
            args.extend(self._vsim_method_params)
        if self._filter:
            args.extend(self._filter.args)
        if self._yield_score_as:
            args.extend(("YIELD_SCORE_AS", self._yield_score_as))
        return args
class HybridQuery:
    """Pairs a text-search query with a vector-similarity query."""

    def __init__(
        self,
        search_query: HybridSearchQuery,
        vector_similarity_query: HybridVsimQuery,
    ) -> None:
        """
        Build a hybrid query from its two parts.

        Args:
            search_query: HybridSearchQuery object holding the text query.
            vector_similarity_query: HybridVsimQuery object holding the
                vector similarity query.
        """
        self._search_query = search_query
        self._vector_similarity_query = vector_similarity_query

    def get_args(self) -> List[str]:
        """
        Return the arguments of both parts as one flat list.

        The text-search arguments come first, followed by the
        vector-similarity arguments.
        """
        return [
            *self._search_query.get_args(),
            *self._vector_similarity_query.get_args(),
        ]
class CombinationMethods(Enum):
    """Methods available for combining the results of the two sub-queries."""

    # Each member's value is the literal keyword placed in the command arguments.
    RRF = "RRF"
    LINEAR = "LINEAR"
@experimental
class CombineResultsMethod:
    """Describes how the results of the two sub-queries are combined."""

    def __init__(self, method: "CombinationMethods", **kwargs) -> None:
        """
        Create a new combine results method object.

        Args:
            method: The combine method to use - RRF or LINEAR.
            kwargs: Additional combine parameters.
                For RRF, supported parameters (at least one should be
                provided):
                    WINDOW: Limits fusion scope.
                    CONSTANT: Controls decay of rank influence.
                    YIELD_SCORE_AS: The name of the field to yield the
                        calculated score as.
                For LINEAR, supported parameters (at least one should be
                provided):
                    ALPHA: The weight of the first query.
                    BETA: The weight of the second query.
                    YIELD_SCORE_AS: The name of the field to yield the
                        calculated score as.

        The additional parameters are not validated and are passed as is to
        the server. Provide them as keyword arguments, for example:
            CombineResultsMethod(CombinationMethods.RRF, WINDOW=3, CONSTANT=0.5)
            CombineResultsMethod(CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5)
        """
        self._method = method
        self._kwargs = kwargs

    def get_args(self) -> List[Union[str, int, float]]:
        """
        Render the clause as a flat argument list.

        Returns:
            ``["COMBINE", <method>]``; when parameters were given this is
            followed by the token count (two per parameter) and then each
            parameter name and value in insertion order.
        """
        args: List[Union[str, int, float]] = ["COMBINE", self._method.value]
        if self._kwargs:
            # One name token plus one value token per parameter.
            args.append(len(self._kwargs) * 2)
            for name, value in self._kwargs.items():
                args.extend((name, value))
        return args
@experimental
class HybridPostProcessingConfig:
    """Collects the post-processing clauses of a hybrid query.

    Clauses are accumulated through the fluent methods and rendered by
    ``build_args`` in a fixed order: LOAD, GROUPBY, APPLY, SORTBY, FILTER,
    LIMIT.
    """

    def __init__(self) -> None:
        """
        Create a new hybrid post processing configuration object.
        """
        self._load_statements = []
        self._apply_statements = []
        self._groupby_statements = []
        self._sortby_fields = []
        self._filter = None
        self._limit = None

    def load(self, *fields: str) -> Self:
        """
        Add load statement parameters to the query.

        Args:
            fields: Field names to load. One string may hold several names
                separated by whitespace; each name is counted separately.

        Returns:
            ``self``, to allow call chaining.
        """
        names: List[str] = []
        for field in fields:
            # split() with no separator drops empty pieces, so doubled,
            # leading or trailing spaces never turn into empty field names
            # or inflate the LOAD count.
            names.extend(field.split())
        if names:
            self._load_statements.extend(("LOAD", len(names), *names))
        return self

    def group_by(self, fields: Union[str, List[str]], *reducers: "Reducer") -> Self:
        """
        Specify by which fields to group the aggregation.

        Args:
            fields: Fields to group by. This can either be a single string or
                a list of strings. In both cases, the field should be
                specified as `@field`.
            reducers: One or more reducers. Reducers may be found in the
                `aggregation` module.

        Returns:
            ``self``, to allow call chaining.
        """
        field_list = [fields] if isinstance(fields, str) else fields
        clause = ["GROUPBY", str(len(field_list)), *field_list]
        for reducer in reducers:
            clause.extend(("REDUCE", reducer.NAME, str(len(reducer.args))))
            clause.extend(reducer.args)
            if reducer._alias is not None:
                clause.extend(("AS", reducer._alias))
        self._groupby_statements.extend(clause)
        return self

    def apply(self, **kwexpr) -> Self:
        """
        Specify one or more projection expressions to add to each result.

        Args:
            kwexpr: One or more key-value pairs for a projection. The key is
                the alias for the projection, and the value is the projection
                expression itself, for example `apply(square_root="sqrt(@foo)")`.

        Returns:
            ``self``, to allow call chaining.
        """
        # Keyword names are always strings, so every projection has an alias.
        for alias, expr in kwexpr.items():
            self._apply_statements.extend(("APPLY", expr, "AS", alias))
        return self

    def sort_by(self, *sortby: "SortbyField") -> Self:
        """
        Set the sortby fields of the query.

        Each call replaces the fields set by any earlier call.
        """
        self._sortby_fields = [*sortby]
        return self

    def filter(self, filter: "HybridFilter") -> Self:
        """
        Set the filter applied during post-processing.

        Only a single filter is kept; a later call replaces an earlier one.

        Args:
            filter: A HybridFilter object; its ``args`` are appended verbatim
                by ``build_args``.
        """
        self._filter = filter
        return self

    def limit(self, offset: int, num: int) -> Self:
        """
        Set the limit parameters of the query.

        Args:
            offset: First argument passed to ``Limit``.
            num: Second argument passed to ``Limit``.
        """
        self._limit = Limit(offset, num)
        return self

    def build_args(self) -> List[Union[str, int]]:
        """
        Render all configured clauses as one flat argument list.

        Returns:
            The LOAD, GROUPBY, APPLY, SORTBY, FILTER and LIMIT arguments, in
            that order, skipping clauses that were never configured. The LOAD
            and SORTBY counts are ints; other tokens are passed through.
        """
        args = []
        if self._load_statements:
            args.extend(self._load_statements)
        if self._groupby_statements:
            args.extend(self._groupby_statements)
        if self._apply_statements:
            args.extend(self._apply_statements)
        if self._sortby_fields:
            sortby_args = []
            for sort_field in self._sortby_fields:
                sortby_args.extend(sort_field.args)
            args.extend(("SORTBY", len(sortby_args), *sortby_args))
        if self._filter:
            args.extend(self._filter.args)
        if self._limit:
            args.extend(self._limit.build_args())
        return args
@experimental
class HybridFilter(Filter):
    """Filter whose keyword is ``FILTER``, built from one conditions string."""

    def __init__(
        self,
        conditions: str,
    ) -> None:
        """
        Build the filter from a conditions expression.

        Args:
            conditions: Filter conditions, handed to the base class right
                after the "FILTER" keyword.
        """
        super().__init__("FILTER", conditions)
@experimental
class HybridCursorQuery:
    """Cursor options of a hybrid query, rendered as a ``WITHCURSOR`` clause."""

    def __init__(self, count: int = 0, max_idle: int = 0) -> None:
        """
        Store the cursor options.

        Args:
            count: Number of results to return per cursor iteration.
            max_idle: Maximum idle time for the cursor.
        """
        self.count = count
        self.max_idle = max_idle

    def build_args(self):
        """
        Render the clause as a flat list of strings.

        Returns:
            ``["WITHCURSOR"]`` followed by ``COUNT <count>`` and
            ``MAXIDLE <max_idle>``; an option is left out when its value is
            falsy (for example the default 0).
        """
        rendered = ["WITHCURSOR"]
        options = (("COUNT", self.count), ("MAXIDLE", self.max_idle))
        for keyword, value in options:
            if value:
                rendered.extend((keyword, str(value)))
        return rendered