
Commit 285d175

Updated task type and artifact download for evaluation (#32734)
* Evaluate updated task types
* New task type and downloading artifacts to output_path
* Spell check fix
1 parent e4edbeb commit 285d175

6 files changed

Lines changed: 73 additions & 17 deletions

File tree

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate
    __init__.py
    _base_handler.py
    _constants.py
    _evaluate.py
    _evaluation_result.py
    _metric_handler.py

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -6,11 +6,13 @@
 
 try:
     from ._evaluate import evaluate
+    from ._evaluation_result import EvaluationResult
 except ModuleNotFoundError as ex:
     print("Please make sure evaluate extras is installed. Please run the following command to install "
           "azure-ai-generative[evaluate]")
     raise ex
 
 __all__ = [
-    "evaluate"
+    "evaluate",
+    "EvaluationResult"
 ]
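With this change, EvaluationResult joins evaluate in the package's public surface, so both can be imported directly (assuming the evaluate extra is installed):

    from azure.ai.generative.evaluate import evaluate, EvaluationResult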

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_base_handler.py

Lines changed: 0 additions & 2 deletions
@@ -1,8 +1,6 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-from ._metric_handler import MetricHandler
-from ._utils import _has_column
 import abc
 import pandas as pd
 

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_constants.py

Lines changed: 12 additions & 2 deletions
@@ -4,9 +4,19 @@
 
 from azureml.metrics import constants
 
-CHAT = "chat"
+QA = "qa"
+QA_RAG = "qa-rag"
+CHAT_RAG = "chat-rag"
+
+SUPPORTED_TASK_TYPE = [QA, QA_RAG, CHAT_RAG]
+
+SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING = {
+    "qa": constants.QUESTION_ANSWERING,
+    "qa-rag": constants.RAG_EVALUATION,
+    "chat-rag": constants.RAG_EVALUATION,
+}
 
 TYPE_TO_KWARGS_MAPPING = {
     constants.QUESTION_ANSWERING: ["questions", "contexts", "y_pred", "y_test"],
-    CHAT: ["y_pred"]
+    constants.RAG_EVALUATION: ["y_pred"]
 }
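The user-facing task-type strings are now decoupled from the azureml.metrics task types: callers pass "qa", "qa-rag", or "chat-rag", and the mapping above translates them before metric computation. A minimal sketch of the validate-then-map pattern the evaluator now uses internally (variable names here are illustrative):

    from azureml.metrics import constants
    from azure.ai.generative.evaluate._constants import (
        SUPPORTED_TASK_TYPE,
        SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING,
    )

    task_type = "chat-rag"  # user-facing task type
    if task_type not in SUPPORTED_TASK_TYPE:
        raise Exception(f"task type {task_type} is not supported")

    # Both RAG flavors resolve to the same azureml.metrics task type.
    metrics_task_type = SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING[task_type]
    assert metrics_task_type == constants.RAG_EVALUATION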

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_evaluate.py

Lines changed: 23 additions & 6 deletions
@@ -11,8 +11,6 @@
 
 import mlflow
 import pandas as pd
-from azureml.metrics import constants
-from azure.ai.generative.evaluate._constants import CHAT
 
 from mlflow.entities import Metric
 from mlflow.exceptions import MlflowException
@@ -21,6 +19,9 @@
 from azure.ai.generative.evaluate._metric_handler import MetricHandler
 from azure.ai.generative.evaluate._utils import _is_flow, load_jsonl, _get_artifact_dir_path
 from azure.ai.generative.evaluate._mlflow_log_collector import RedirectUserOutputStreams
+from azure.ai.generative.evaluate._constants import SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING, SUPPORTED_TASK_TYPE, QA_RAG, \
+    CHAT_RAG
+from azure.ai.generative.evaluate._evaluation_result import EvaluationResult
 
 from ._utils import _write_properties_to_run_history
 
@@ -85,6 +86,7 @@ def evaluate(
     metrics_list=None,
     model_config=None,
     data_mapping=None,
+    output_path=None,
     **kwargs
 ):
     results_list = []
@@ -117,6 +119,7 @@ def evaluate(
             data_mapping=data_mapping,
             params_dict=params_permutations_dict,
             metrics=metrics_list,
+            output_path=output_path,
             **kwargs
         )
         results_list.append(evaluation_results)
@@ -130,6 +133,7 @@ def evaluate(
             model_config=model_config,
            data_mapping=data_mapping,
             metrics=metrics_list,
+            output_path=output_path,
             **kwargs
         )
 
@@ -146,6 +150,7 @@ def _evaluate(
     metrics=None,
     data_mapping=None,
     model_config=None,
+    output_path=None,
     **kwargs
 ):
     try:
@@ -166,7 +171,7 @@ def _evaluate(
     if target is None and prediction_data is None:
         raise Exception("target and prediction data cannot be null")
 
-    if task_type not in [constants.Tasks.QUESTION_ANSWERING, CHAT]:
+    if task_type not in SUPPORTED_TASK_TYPE:
         raise Exception(f"task type {task_type} is not supported")
 
     metrics_config = {}
@@ -195,7 +200,7 @@ def _evaluate(
     )
 
     metrics_handler = MetricHandler(
-        task_type=task_type,
+        task_type=SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING[task_type],
         metrics=metrics,
         prediction_data=asset_handler.prediction_data,
         truth_data=asset_handler.ground_truth,
@@ -209,7 +214,7 @@ def _evaluate(
 
     def _get_instance_table():
         metrics.get("artifacts").pop("bertscore", None)
-        if task_type == "chat":
+        if task_type in [QA_RAG, CHAT_RAG]:
             instance_level_metrics_table = _get_chat_instance_table(metrics.get("artifacts"))
         else:
             instance_level_metrics_table = pd.DataFrame(metrics.get("artifacts"))
@@ -270,7 +275,19 @@ def _get_instance_table():
     mlflow.log_param("task_type", task_type)
     log_param_and_tag("_azureml.evaluate_metric_mapping", json.dumps(metrics_handler._metrics_mapping_to_log))
 
-    return metrics
+    evaluation_result = EvaluationResult(
+        metrics_summary=metrics.get("metrics"),
+        artifacts={
+            "eval_results.jsonl": f"runs:/{run.info.run_id}/eval_results.jsonl"
+        },
+        tracking_uri=kwargs.get("tracking_uri"),
+        evaluation_id=run.info.run_id
+    )
+    if output_path:
+        evaluation_result.download_evaluation_artifacts(path=output_path)
+
+    return evaluation_result
+
 
 
 def log_input(data, data_is_file):
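Taken together, these hunks let callers pass output_path to evaluate and get an EvaluationResult back instead of the raw metrics dict. A hedged sketch of the new call shape; only task_type, data_mapping, metrics_list, model_config, and output_path are visible in this diff, so the remaining argument names and all values are illustrative placeholders:

    from azure.ai.generative.evaluate import evaluate

    def my_target(question):  # placeholder target; real targets may be models or flows
        return {"answer": "..."}

    result = evaluate(
        target=my_target,                        # assumed parameter name
        data="eval_data.jsonl",                  # placeholder dataset path
        task_type="qa-rag",                      # one of "qa", "qa-rag", "chat-rag"
        data_mapping={"questions": "question"},  # placeholder mapping
        output_path="./eval_output",             # new: eval_results.jsonl is downloaded here
    )
    print(result.metrics_summary)
    print(result.artifacts)  # {"eval_results.jsonl": "runs:/<run_id>/eval_results.jsonl"}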
sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_evaluation_result.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+
+
+class EvaluationResult(object):
+
+    def __init__(self, metrics_summary: dict, artifacts: dict, **kwargs):
+        self._metrics_summary = metrics_summary
+        self._artifacts = artifacts
+        self._tracking_uri = kwargs.get("tracking_uri")
+        self._evaluation_id = kwargs.get("evaluation_id")
+
+    @property
+    def metrics_summary(self) -> dict[str, float]:
+        return self._metrics_summary
+
+    @property
+    def artifacts(self) -> dict[str, str]:
+        return self._artifacts
+
+    @property
+    def tracking_uri(self) -> str:
+        return self._tracking_uri
+
+    def download_evaluation_artifacts(self, path: str) -> None:
+        from mlflow.artifacts import download_artifacts
+        for artifact, artifact_uri in self.artifacts.items():
+            download_artifacts(
+                artifact_uri=artifact_uri,
+                tracking_uri=self.tracking_uri,
+                dst_path=path
+            )
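download_evaluation_artifacts delegates to MLflow's mlflow.artifacts.download_artifacts, resolving each runs:/ URI against the stored tracking URI. A minimal consumption sketch (the URIs and metric values below are placeholders, not output from this commit):

    from azure.ai.generative.evaluate import EvaluationResult

    result = EvaluationResult(
        metrics_summary={"gpt_groundedness": 4.2},  # placeholder value
        artifacts={"eval_results.jsonl": "runs:/<run_id>/eval_results.jsonl"},
        tracking_uri="<workspace-tracking-uri>",    # placeholder
        evaluation_id="<run_id>",
    )
    result.download_evaluation_artifacts(path="./eval_output")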

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_metric_handler.py

Lines changed: 5 additions & 6 deletions
@@ -3,7 +3,7 @@
 # ---------------------------------------------------------
 import copy
 
-from azure.ai.generative.evaluate._constants import TYPE_TO_KWARGS_MAPPING
+from azure.ai.generative.evaluate._constants import TYPE_TO_KWARGS_MAPPING, CHAT_RAG, QA_RAG
 
 
 class MetricHandler(object):
@@ -64,8 +64,8 @@ def _get_data_for_metrics(self):
                     data_column: data_source[metrics_mapping[data_column]].values.tolist()
                 }
             )
-            poped_value = metrics_mapping.pop(data_column, None)
-            metrics_mapping_to_log[data_column] = poped_value
+            popped_value = metrics_mapping.pop(data_column, None)
+            metrics_mapping_to_log[data_column] = popped_value
 
         metrics_data.update(metrics_mapping)
 
@@ -75,14 +75,13 @@ def _get_data_for_metrics(self):
 
     def calculate_metrics(self):
         from azureml.metrics import compute_metrics, constants
-        from ._constants import CHAT
 
         metrics_calculation_data = self._get_data_for_metrics()
 
-        metrics = self.metrics if self.task_type != CHAT else []
+        metrics = self.metrics if self.task_type == constants.RAG_EVALUATION and self.metrics is not None else []
 
         return compute_metrics(
             metrics=metrics,
-            task_type=constants.Tasks.RAG_EVALUATION if self.task_type == CHAT else self.task_type,
+            task_type=self.task_type,
             **metrics_calculation_data,
         )
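MetricHandler no longer translates task types itself: it receives an azureml.metrics task type already mapped by the caller and forwards it to compute_metrics unchanged. A rough sketch of the resulting call for a RAG evaluation, under the assumption that compute_metrics accepts y_pred for this task type (per TYPE_TO_KWARGS_MAPPING above); the conversation data is a placeholder:

    from azureml.metrics import compute_metrics, constants

    conversations = [...]  # placeholder: per-conversation prediction data
    result = compute_metrics(
        metrics=[],                          # empty unless the caller supplied RAG metrics
        task_type=constants.RAG_EVALUATION,  # already mapped; no CHAT translation here
        y_pred=conversations,
    )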
