diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 9064f8b7ee..ab7a63251b 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -11,7 +11,7 @@ HybridQuery, ) from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult -from redis.utils import deprecated_function +from redis.utils import deprecated_function, experimental_method from ..helpers import get_protocol_version from ._util import to_string @@ -560,6 +560,7 @@ def search( SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 ) + @experimental_method() def hybrid_search( self, query: HybridQuery, @@ -1053,6 +1054,7 @@ async def search( SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 ) + @experimental_method() async def hybrid_search( self, query: HybridQuery, diff --git a/redis/commands/search/hybrid_query.py b/redis/commands/search/hybrid_query.py index 3b072082b2..fe51f4d161 100644 --- a/redis/commands/search/hybrid_query.py +++ b/redis/commands/search/hybrid_query.py @@ -25,7 +25,11 @@ def __init__( Args: query_string: The query string. - scorer: The scorer to use. Allowed values are "TFIDF" or "BM25". + scorer: Scoring algorithm for text search query. + Allowed values are "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE", + "BM25", "BM25STD", "BM25STD.TANH", "HAMMING", etc. + For more information about supported scoring algorithms, see + https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/ yield_score_as: The name of the field to yield the score as. """ self._query_string = query_string @@ -39,9 +43,10 @@ def query_string(self) -> str: def scorer(self, scorer: str) -> "HybridSearchQuery": """ Scoring algorithm for text search query. - Allowed values are "TFIDF", "DISMAX", "DOCSCORE", "BM25", etc. + Allowed values are "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE", "BM25", + "BM25STD", "BM25STD.TANH", "HAMMING", etc. - For more information about supported scroring algorithms, + For more information about supported scoring algorithms, see https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/ """ self._scorer = scorer diff --git a/redis/utils.py b/redis/utils.py index 8a8bce6de1..799294f0f5 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,4 +1,5 @@ import datetime +import inspect import logging import textwrap import warnings @@ -125,12 +126,22 @@ def deprecated_function(reason="", version="", name=None): """ def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - warn_deprecated(name or func.__name__, reason, version, stacklevel=3) - return func(*args, **kwargs) + if inspect.iscoroutinefunction(func): + # Create async wrapper for async functions + @wraps(func) + async def async_wrapper(*args, **kwargs): + warn_deprecated(name or func.__name__, reason, version, stacklevel=3) + return await func(*args, **kwargs) + + return async_wrapper + else: + # Create regular wrapper for sync functions + @wraps(func) + def wrapper(*args, **kwargs): + warn_deprecated(name or func.__name__, reason, version, stacklevel=3) + return func(*args, **kwargs) - return wrapper + return wrapper return decorator @@ -158,9 +169,26 @@ def warn_deprecated_arg_usage( C = TypeVar("C", bound=Callable) +def _get_filterable_args( + func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None +) -> dict: + """ + Extract arguments from function call that should be checked for deprecation/experimental warnings. + Excludes 'self' and any explicitly allowed args. + """ + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + filterable_args = dict(zip(arg_names, args)) + filterable_args.update(kwargs) + filterable_args.pop("self", None) + if allowed_args: + for allowed_arg in allowed_args: + filterable_args.pop(allowed_arg, None) + return filterable_args + + def deprecated_args( - args_to_warn: list = ["*"], - allowed_args: list = [], + args_to_warn: Optional[List[str]] = None, + allowed_args: Optional[List[str]] = None, reason: str = "", version: str = "", ) -> Callable[[C], C]: @@ -168,37 +196,46 @@ def deprecated_args( Decorator to mark specified args of a function as deprecated. If '*' is in args_to_warn, all arguments will be marked as deprecated. """ + if args_to_warn is None: + args_to_warn = ["*"] + if allowed_args is None: + allowed_args = [] + + def _check_deprecated_args(func, filterable_args): + """Check and warn about deprecated arguments.""" + for arg in args_to_warn: + if arg == "*" and len(filterable_args) > 0: + warn_deprecated_arg_usage( + list(filterable_args.keys()), + func.__name__, + reason, + version, + stacklevel=5, + ) + elif arg in filterable_args: + warn_deprecated_arg_usage( + arg, func.__name__, reason, version, stacklevel=5 + ) def decorator(func: C) -> C: - @wraps(func) - def wrapper(*args, **kwargs): - # Get function argument names - arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] - - provided_args = dict(zip(arg_names, args)) - provided_args.update(kwargs) - - provided_args.pop("self", None) - for allowed_arg in allowed_args: - provided_args.pop(allowed_arg, None) - - for arg in args_to_warn: - if arg == "*" and len(provided_args) > 0: - warn_deprecated_arg_usage( - list(provided_args.keys()), - func.__name__, - reason, - version, - stacklevel=3, - ) - elif arg in provided_args: - warn_deprecated_arg_usage( - arg, func.__name__, reason, version, stacklevel=3 - ) - - return func(*args, **kwargs) - - return wrapper + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + filterable_args = _get_filterable_args(func, args, kwargs, allowed_args) + _check_deprecated_args(func, filterable_args) + return await func(*args, **kwargs) + + return async_wrapper + else: + + @wraps(func) + def wrapper(*args, **kwargs): + filterable_args = _get_filterable_args(func, args, kwargs, allowed_args) + _check_deprecated_args(func, filterable_args) + return func(*args, **kwargs) + + return wrapper return decorator @@ -368,12 +405,22 @@ def experimental_method() -> Callable[[C], C]: """ def decorator(func: C) -> C: - @wraps(func) - def wrapper(*args, **kwargs): - warn_experimental(func.__name__, stacklevel=2) - return func(*args, **kwargs) + if inspect.iscoroutinefunction(func): + # Create async wrapper for async functions + @wraps(func) + async def async_wrapper(*args, **kwargs): + warn_experimental(func.__name__, stacklevel=2) + return await func(*args, **kwargs) + + return async_wrapper + else: + # Create regular wrapper for sync functions + @wraps(func) + def wrapper(*args, **kwargs): + warn_experimental(func.__name__, stacklevel=2) + return func(*args, **kwargs) - return wrapper + return wrapper return decorator @@ -393,32 +440,45 @@ def warn_experimental_arg_usage( def experimental_args( - args_to_warn: list = ["*"], + args_to_warn: Optional[List[str]] = None, ) -> Callable[[C], C]: """ Decorator to mark specified args of a function as experimental. + If '*' is in args_to_warn, all arguments will be marked as experimental. """ + if args_to_warn is None: + args_to_warn = ["*"] + + def _check_experimental_args(func, filterable_args): + """Check and warn about experimental arguments.""" + for arg in args_to_warn: + if arg == "*" and len(filterable_args) > 0: + warn_experimental_arg_usage( + list(filterable_args.keys()), func.__name__, stacklevel=4 + ) + elif arg in filterable_args: + warn_experimental_arg_usage(arg, func.__name__, stacklevel=4) def decorator(func: C) -> C: - @wraps(func) - def wrapper(*args, **kwargs): - # Get function argument names - arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + if inspect.iscoroutinefunction(func): - provided_args = dict(zip(arg_names, args)) - provided_args.update(kwargs) + @wraps(func) + async def async_wrapper(*args, **kwargs): + filterable_args = _get_filterable_args(func, args, kwargs) + if len(filterable_args) > 0: + _check_experimental_args(func, filterable_args) + return await func(*args, **kwargs) - provided_args.pop("self", None) + return async_wrapper + else: - if len(provided_args) == 0: + @wraps(func) + def wrapper(*args, **kwargs): + filterable_args = _get_filterable_args(func, args, kwargs) + if len(filterable_args) > 0: + _check_experimental_args(func, filterable_args) return func(*args, **kwargs) - for arg in args_to_warn: - if arg in provided_args: - warn_experimental_arg_usage(arg, func.__name__, stacklevel=3) - - return func(*args, **kwargs) - - return wrapper + return wrapper return decorator diff --git a/tests/test_asyncio/test_utils.py b/tests/test_asyncio/test_utils.py index 05cad1bfaf..47f383f3ef 100644 --- a/tests/test_asyncio/test_utils.py +++ b/tests/test_asyncio/test_utils.py @@ -1,8 +1,141 @@ from datetime import datetime +import warnings +import pytest import redis +from redis.utils import ( + deprecated_function, + deprecated_args, + experimental_method, + experimental_args, +) async def redis_server_time(client: redis.Redis): seconds, milliseconds = await client.time() timestamp = float(f"{seconds}.{milliseconds}") return datetime.fromtimestamp(timestamp) + + +# Async tests for deprecated_function decorator +class TestDeprecatedFunctionAsync: + @pytest.mark.asyncio + async def test_async_function_warns(self): + @deprecated_function(reason="use new_async_func", version="2.0.0") + async def old_async_func(): + return "async_result" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await old_async_func() + assert result == "async_result" + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "old_async_func" in str(w[0].message) + + @pytest.mark.asyncio + async def test_async_preserves_function_metadata(self): + @deprecated_function() + async def async_documented_func(): + """This is the async docstring.""" + pass + + assert async_documented_func.__name__ == "async_documented_func" + assert async_documented_func.__doc__ == "This is the async docstring." + + +# Async tests for deprecated_args decorator +class TestDeprecatedArgsAsync: + @pytest.mark.asyncio + async def test_async_function_warns_on_deprecated_arg(self): + @deprecated_args(args_to_warn=["old_param"], reason="use new_param") + async def async_func_with_args(new_param=None, old_param=None): + return new_param or old_param + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await async_func_with_args(old_param="async_value") + assert result == "async_value" + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "old_param" in str(w[0].message) + + @pytest.mark.asyncio + async def test_async_function_no_warning_on_allowed_arg(self): + @deprecated_args(args_to_warn=["*"], allowed_args=["allowed_param"]) + async def async_func_with_allowed(allowed_param=None): + return allowed_param + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await async_func_with_allowed(allowed_param="async_value") + assert result == "async_value" + assert len(w) == 0 + + @pytest.mark.asyncio + async def test_async_wildcard_warns_all_args(self): + @deprecated_args(args_to_warn=["*"]) + async def async_func_all_deprecated(param1=None, param2=None): + return (param1, param2) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await async_func_all_deprecated(param1="a", param2="b") + assert result == ("a", "b") + assert len(w) == 1 + assert "param1" in str(w[0].message) or "param2" in str(w[0].message) + + +# Async tests for experimental_method decorator +class TestExperimentalMethodAsync: + @pytest.mark.asyncio + async def test_async_function_warns(self): + @experimental_method() + async def async_experimental_func(): + return "async_experimental_result" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await async_experimental_func() + assert result == "async_experimental_result" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "async_experimental_func" in str(w[0].message) + + @pytest.mark.asyncio + async def test_async_preserves_function_metadata(self): + @experimental_method() + async def async_experimental_documented(): + """Experimental async docstring.""" + pass + + assert async_experimental_documented.__name__ == "async_experimental_documented" + assert async_experimental_documented.__doc__ == "Experimental async docstring." + + +# Async tests for experimental_args decorator +class TestExperimentalArgsAsync: + @pytest.mark.asyncio + async def test_async_function_warns_on_experimental_arg(self): + @experimental_args(args_to_warn=["beta_param"]) + async def async_func_with_experimental(stable_param=None, beta_param=None): + return stable_param or beta_param + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await async_func_with_experimental(beta_param="async_beta") + assert result == "async_beta" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "beta_param" in str(w[0].message) + + @pytest.mark.asyncio + async def test_async_no_warning_when_no_args_provided(self): + @experimental_args(args_to_warn=["beta_param"]) + async def async_func_no_args(): + return "no_args" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await async_func_no_args() + assert result == "no_args" + assert len(w) == 0 diff --git a/tests/test_search.py b/tests/test_search.py index a5f42f62ed..b5576f8826 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -4250,6 +4250,45 @@ def test_hybrid_search_query_with_scorer(self, client): ) assert res["warnings"] == [] + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_supported_scorer(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + search_query = HybridSearchQuery("shoes") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data="$vec", + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + supported_scorers = [ + "TFIDF", + "TFIDF.DOCNORM", + "BM25", + "BM25STD", + "BM25STD.TANH", + "DISMAX", + "DOCSCORE", + "HAMMING", + ] + for scorer in supported_scorers: + search_query.scorer(scorer) + + res = client.ft().hybrid_search( + query=hybrid_query, + params_substitution={ + "vec": np.array([1, 2, 2, 3], dtype=np.float32).tobytes() + }, + timeout=10, + ) + assert res is not None + @pytest.mark.redismod @skip_if_server_version_lt("8.3.224") def test_hybrid_search_query_with_vsim_method_defined_query_init(self, client): diff --git a/tests/test_utils.py b/tests/test_utils.py index 75de8dbb9f..a9ca17ce91 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,13 @@ from datetime import datetime +import warnings import pytest -from redis.utils import compare_versions +from redis.utils import ( + compare_versions, + deprecated_function, + deprecated_args, + experimental_method, + experimental_args, +) @pytest.mark.parametrize( @@ -32,3 +39,112 @@ def redis_server_time(client): seconds, milliseconds = client.time() timestamp = float(f"{seconds}.{milliseconds}") return datetime.fromtimestamp(timestamp) + + +# Tests for deprecated_function decorator +class TestDeprecatedFunction: + def test_sync_function_warns(self): + @deprecated_function(reason="use new_func", version="1.0.0") + def old_func(): + return "result" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = old_func() + assert result == "result" + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "old_func" in str(w[0].message) + assert "use new_func" in str(w[0].message) + assert "1.0.0" in str(w[0].message) + + def test_preserves_function_metadata(self): + @deprecated_function() + def documented_func(): + """This is the docstring.""" + pass + + assert documented_func.__name__ == "documented_func" + assert documented_func.__doc__ == "This is the docstring." + + +# Tests for deprecated_args decorator +class TestDeprecatedArgs: + def test_sync_function_warns_on_deprecated_arg(self): + @deprecated_args(args_to_warn=["old_param"], reason="use new_param") + def func_with_args(new_param=None, old_param=None): + return new_param or old_param + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = func_with_args(old_param="value") + assert result == "value" + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "old_param" in str(w[0].message) + + def test_sync_function_no_warning_on_allowed_arg(self): + @deprecated_args(args_to_warn=["*"], allowed_args=["allowed_param"]) + def func_with_allowed(allowed_param=None): + return allowed_param + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = func_with_allowed(allowed_param="value") + assert result == "value" + assert len(w) == 0 + + def test_wildcard_warns_all_args(self): + @deprecated_args(args_to_warn=["*"]) + def func_all_deprecated(param1=None, param2=None): + return (param1, param2) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = func_all_deprecated(param1="a", param2="b") + assert result == ("a", "b") + assert len(w) == 1 + assert "param1" in str(w[0].message) or "param2" in str(w[0].message) + + +# Tests for experimental_method decorator +class TestExperimentalMethod: + def test_sync_function_warns(self): + @experimental_method() + def experimental_func(): + return "experimental_result" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = experimental_func() + assert result == "experimental_result" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "experimental_func" in str(w[0].message) + + +# Tests for experimental_args decorator +class TestExperimentalArgs: + def test_sync_function_warns_on_experimental_arg(self): + @experimental_args(args_to_warn=["beta_param"]) + def func_with_experimental(stable_param=None, beta_param=None): + return stable_param or beta_param + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = func_with_experimental(beta_param="beta_value") + assert result == "beta_value" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "beta_param" in str(w[0].message) + + def test_no_warning_when_no_args_provided(self): + @experimental_args(args_to_warn=["beta_param"]) + def func_no_args(): + return "no_args" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = func_no_args() + assert result == "no_args" + assert len(w) == 0