Skip to content

Commit 9259314

Browse files
[textanalytics] decorator to validate multiapi (#24281)
* add validation decorator for multiapi args since inputs changed from v3.x to language api * add tests
1 parent 7b2d19c commit 9259314

23 files changed

Lines changed: 582 additions & 274 deletions

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_check.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,12 @@ def is_language_api(api_version):
1111
"""Language API is date-based
1212
"""
1313
return re.search(r'\d{4}-\d{2}-\d{2}', api_version)
14+
15+
16+
def string_index_type_compatibility(string_index_type):
17+
"""Language API changed this string_index_type option to plural.
18+
Convert singular to plural for language API
19+
"""
20+
if string_index_type == "TextElement_v8":
21+
return "TextElements_v8"
22+
return string_index_type

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,7 @@
1313
from ._generated.v3_0 import models as _v3_0_models
1414
from ._generated.v3_1 import models as _v3_1_models
1515
from ._generated.v2022_03_01_preview import models as _v2022_03_01_preview_models
16-
from ._check import is_language_api
17-
18-
19-
def string_index_type_compatibility(string_index_type):
20-
"""Language API changed this string_index_type option to plural.
21-
Convert singular to plural for language API
22-
"""
23-
if string_index_type == "TextElement_v8":
24-
return "TextElements_v8"
25-
return string_index_type
16+
from ._check import is_language_api, string_index_type_compatibility
2617

2718

2819
def _get_indices(relation):

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
TextDocumentInput,
1010
_AnalyzeActionsType,
1111
)
12-
from ._check import is_language_api
1312

1413

1514
def _validate_input(documents, hint, whole_input_hint):
@@ -99,25 +98,3 @@ def _determine_action_type(action): # pylint: disable=too-many-return-statement
9998
if action.__class__.__name__ == "CustomMultiLabelClassificationLROTask":
10099
return _AnalyzeActionsType.MULTI_CATEGORY_CLASSIFY
101100
return _AnalyzeActionsType.EXTRACT_KEY_PHRASES
102-
103-
104-
def _check_string_index_type_arg(
105-
string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint"
106-
):
107-
string_index_type = None
108-
109-
if api_version == "v3.0":
110-
if string_index_type_arg is not None:
111-
raise ValueError(
112-
"'string_index_type' is only available for API version V3_1 and up"
113-
)
114-
elif is_language_api(api_version) and string_index_type_arg == "TextElement_v8":
115-
return "TextElements_v8"
116-
else:
117-
if string_index_type_arg is None:
118-
string_index_type = string_index_type_default
119-
120-
else:
121-
string_index_type = string_index_type_arg
122-
123-
return string_index_type

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py

Lines changed: 44 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
from azure.core.tracing.decorator import distributed_trace
1717
from azure.core.exceptions import HttpResponseError
1818
from azure.core.credentials import AzureKeyCredential
19-
from ._base_client import TextAnalyticsClientBase, TextAnalyticsApiVersion
19+
from ._base_client import TextAnalyticsClientBase
2020
from ._lro import AnalyzeActionsLROPoller, AnalyzeHealthcareEntitiesLROPoller
2121
from ._request_handlers import (
2222
_validate_input,
2323
_determine_action_type,
24-
_check_string_index_type_arg,
2524
)
25+
from ._validate import validate_multiapi_args, check_for_unsupported_actions_types
2626
from ._version import DEFAULT_API_VERSION
2727
from ._response_handlers import (
2828
process_http_response_error,
@@ -67,7 +67,7 @@
6767
MultiCategoryClassifyResult,
6868
_AnalyzeActionsType,
6969
)
70-
from ._check import is_language_api
70+
from ._check import is_language_api, string_index_type_compatibility
7171

7272
if TYPE_CHECKING:
7373
from azure.core.credentials import TokenCredential
@@ -132,6 +132,10 @@ def __init__(
132132
)
133133

134134
@distributed_trace
135+
@validate_multiapi_args(
136+
version_method_added="v3.0",
137+
args_mapping={"v3.1": ["disable_service_logs"]}
138+
)
135139
def detect_language(
136140
self,
137141
documents: Union[List[str], List[DetectLanguageInput], List[Dict[str, str]]],
@@ -228,6 +232,10 @@ def detect_language(
228232
return process_http_response_error(error)
229233

230234
@distributed_trace
235+
@validate_multiapi_args(
236+
version_method_added="v3.0",
237+
args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]}
238+
)
231239
def recognize_entities(
232240
self,
233241
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
@@ -294,11 +302,7 @@ def recognize_entities(
294302
docs = _validate_input(documents, "language", language)
295303
model_version = kwargs.pop("model_version", None)
296304
show_stats = kwargs.pop("show_stats", None)
297-
string_index_type = _check_string_index_type_arg(
298-
kwargs.pop("string_index_type", None),
299-
self._api_version,
300-
string_index_type_default=self._string_index_type_default,
301-
)
305+
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
302306
disable_service_logs = kwargs.pop("disable_service_logs", None)
303307

304308
try:
@@ -310,7 +314,7 @@ def recognize_entities(
310314
parameters=models.EntitiesTaskParameters(
311315
logging_opt_out=disable_service_logs,
312316
model_version=model_version,
313-
string_index_type=string_index_type
317+
string_index_type=string_index_type_compatibility(string_index_type)
314318
)
315319
),
316320
show_stats=show_stats,
@@ -332,6 +336,9 @@ def recognize_entities(
332336
return process_http_response_error(error)
333337

334338
@distributed_trace
339+
@validate_multiapi_args(
340+
version_method_added="v3.1"
341+
)
335342
def recognize_pii_entities(
336343
self,
337344
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
@@ -408,11 +415,7 @@ def recognize_pii_entities(
408415
show_stats = kwargs.pop("show_stats", None)
409416
domain_filter = kwargs.pop("domain_filter", None)
410417
categories_filter = kwargs.pop("categories_filter", None)
411-
string_index_type = _check_string_index_type_arg(
412-
kwargs.pop("string_index_type", None),
413-
self._api_version,
414-
string_index_type_default=self._string_index_type_default,
415-
)
418+
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
416419
disable_service_logs = kwargs.pop("disable_service_logs", None)
417420

418421
try:
@@ -426,7 +429,7 @@ def recognize_pii_entities(
426429
model_version=model_version,
427430
domain=domain_filter,
428431
pii_categories=categories_filter,
429-
string_index_type=string_index_type
432+
string_index_type=string_index_type_compatibility(string_index_type)
430433
)
431434
),
432435
show_stats=show_stats,
@@ -446,19 +449,14 @@ def recognize_pii_entities(
446449
cls=kwargs.pop("cls", pii_entities_result),
447450
**kwargs
448451
)
449-
except ValueError as error:
450-
if (
451-
"API version v3.0 does not have operation 'entities_recognition_pii'"
452-
in str(error)
453-
):
454-
raise ValueError(
455-
"'recognize_pii_entities' endpoint is only available for API version V3_1 and up"
456-
) from error
457-
raise error
458452
except HttpResponseError as error:
459453
return process_http_response_error(error)
460454

461455
@distributed_trace
456+
@validate_multiapi_args(
457+
version_method_added="v3.0",
458+
args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]}
459+
)
462460
def recognize_linked_entities(
463461
self,
464462
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
@@ -527,11 +525,7 @@ def recognize_linked_entities(
527525
model_version = kwargs.pop("model_version", None)
528526
show_stats = kwargs.pop("show_stats", None)
529527
disable_service_logs = kwargs.pop("disable_service_logs", None)
530-
string_index_type = _check_string_index_type_arg(
531-
kwargs.pop("string_index_type", None),
532-
self._api_version,
533-
string_index_type_default=self._string_index_type_default,
534-
)
528+
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
535529

536530
try:
537531
if is_language_api(self._api_version):
@@ -542,7 +536,7 @@ def recognize_linked_entities(
542536
parameters=models.EntityLinkingTaskParameters(
543537
logging_opt_out=disable_service_logs,
544538
model_version=model_version,
545-
string_index_type=string_index_type
539+
string_index_type=string_index_type_compatibility(string_index_type)
546540
)
547541
),
548542
show_stats=show_stats,
@@ -581,6 +575,10 @@ def _healthcare_result_callback(
581575
)
582576

583577
@distributed_trace
578+
@validate_multiapi_args(
579+
version_method_added="v3.1",
580+
args_mapping={"2022-03-01-preview": ["display_name"]}
581+
)
584582
def begin_analyze_healthcare_entities(
585583
self,
586584
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
@@ -658,11 +656,7 @@ def begin_analyze_healthcare_entities(
658656
show_stats = kwargs.pop("show_stats", None)
659657
polling_interval = kwargs.pop("polling_interval", 5)
660658
continuation_token = kwargs.pop("continuation_token", None)
661-
string_index_type = _check_string_index_type_arg(
662-
kwargs.pop("string_index_type", None),
663-
self._api_version,
664-
string_index_type_default=self._string_index_type_default,
665-
)
659+
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
666660
disable_service_logs = kwargs.pop("disable_service_logs", None)
667661
display_name = kwargs.pop("display_name", None)
668662

@@ -710,7 +704,7 @@ def get_result_from_cont_token(initial_response, pipeline_response):
710704
parameters=models.HealthcareTaskParameters(
711705
model_version=model_version,
712706
logging_opt_out=disable_service_logs,
713-
string_index_type=string_index_type,
707+
string_index_type=string_index_type_compatibility(string_index_type)
714708
)
715709
)
716710
]
@@ -755,19 +749,14 @@ def get_result_from_cont_token(initial_response, pipeline_response):
755749
continuation_token=continuation_token,
756750
**kwargs
757751
)
758-
759-
except ValueError as error:
760-
if "API version v3.0 does not have operation 'begin_health'" in str(error):
761-
raise ValueError(
762-
"'begin_analyze_healthcare_entities' method is only available for API version \
763-
V3_1 and up."
764-
) from error
765-
raise error
766-
767752
except HttpResponseError as error:
768753
return process_http_response_error(error)
769754

770755
@distributed_trace
756+
@validate_multiapi_args(
757+
version_method_added="v3.0",
758+
args_mapping={"v3.1": ["disable_service_logs"]}
759+
)
771760
def extract_key_phrases(
772761
self,
773762
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
@@ -862,6 +851,10 @@ def extract_key_phrases(
862851
return process_http_response_error(error)
863852

864853
@distributed_trace
854+
@validate_multiapi_args(
855+
version_method_added="v3.0",
856+
args_mapping={"v3.1": ["show_opinion_mining", "disable_service_logs", "string_index_type"]}
857+
)
865858
def analyze_sentiment(
866859
self,
867860
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
@@ -936,19 +929,7 @@ def analyze_sentiment(
936929
show_stats = kwargs.pop("show_stats", None)
937930
show_opinion_mining = kwargs.pop("show_opinion_mining", None)
938931
disable_service_logs = kwargs.pop("disable_service_logs", None)
939-
string_index_type = _check_string_index_type_arg(
940-
kwargs.pop("string_index_type", None),
941-
self._api_version,
942-
string_index_type_default=self._string_index_type_default,
943-
)
944-
if show_opinion_mining is not None:
945-
if (
946-
self._api_version == TextAnalyticsApiVersion.V3_0
947-
and show_opinion_mining
948-
):
949-
raise ValueError(
950-
"'show_opinion_mining' is only available for API version v3.1 and up"
951-
)
932+
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
952933

953934
try:
954935
if is_language_api(self._api_version):
@@ -959,7 +940,7 @@ def analyze_sentiment(
959940
parameters=models.SentimentAnalysisTaskParameters(
960941
logging_opt_out=disable_service_logs,
961942
model_version=model_version,
962-
string_index_type=string_index_type,
943+
string_index_type=string_index_type_compatibility(string_index_type),
963944
opinion_mining=show_opinion_mining,
964945
)
965946
),
@@ -1001,6 +982,10 @@ def _analyze_result_callback(
1001982
)
1002983

1003984
@distributed_trace
985+
@validate_multiapi_args(
986+
version_method_added="v3.1",
987+
custom_wrapper=check_for_unsupported_actions_types
988+
)
1004989
def begin_analyze_actions(
1005990
self,
1006991
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
@@ -1253,13 +1238,5 @@ def get_result_from_cont_token(initial_response, pipeline_response):
12531238
continuation_token=continuation_token,
12541239
**kwargs
12551240
)
1256-
1257-
except ValueError as error:
1258-
if "API version v3.0 does not have operation 'begin_analyze'" in str(error):
1259-
raise ValueError(
1260-
"'begin_analyze_actions' endpoint is only available for API version V3_1 and up"
1261-
) from error
1262-
raise error
1263-
12641241
except HttpResponseError as error:
12651242
return process_http_response_error(error)

0 commit comments

Comments
 (0)