Skip to content

Commit d89cb24

Browse files
authored
Support Model Stage in the SDK (#30254)
* Support Model Stage in SDK/CLI * update * fix lint * should fix test_ipp_model * fix list, test, and add recording * fix test * fix lint * fix it * run black * Revert "fix it" This reverts commit d3c337c. * Revert "run black" This reverts commit 8cb4113. * update * cleanup * remove extra spaces * fixes * update changelog
1 parent 29e78d9 commit d89cb24

8 files changed

Lines changed: 496 additions & 2 deletions

File tree

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Features Added
66
- Added support to enable set workspace connection secret expiry time.
7+
- Added support for `stage` on model version
78

89
### Bugs Fixed
910

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/assets/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class ModelSchema(PathAwareSchema):
3535
description = fields.Str()
3636
properties = fields.Dict()
3737
tags = fields.Dict()
38+
stage = fields.Str()
3839
utc_time_created = fields.DateTime(format="iso", dump_only=True)
3940
flavors = fields.Dict()
4041
creation_context = NestedField(CreationContextSchema, dump_only=True)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .artifact import ArtifactStorageInfo
3030

3131

32-
class Model(Artifact):
32+
class Model(Artifact): # pylint: disable=too-many-instance-attributes
3333
"""Model for training and scoring.
3434
3535
:param name: Name of the resource.
@@ -54,6 +54,8 @@ class Model(Artifact):
5454
:type tags: dict[str, str]
5555
:param properties: The asset property dictionary.
5656
:type properties: dict[str, str]
57+
:param stage: The stage of the resource.
58+
:type stage: str
5759
:param kwargs: A dictionary of additional configuration parameters.
5860
:type kwargs: dict
5961
"""
@@ -70,6 +72,7 @@ def __init__(
7072
description: Optional[str] = None,
7173
tags: Optional[Dict] = None,
7274
properties: Optional[Dict] = None,
75+
stage: Optional[str] = None,
7376
**kwargs,
7477
):
7578
self.job_name = kwargs.pop("job_name", None)
@@ -87,6 +90,7 @@ def __init__(
8790
self.flavors = dict(flavors) if flavors else None
8891
self._arm_type = ArmConstants.MODEL_VERSION_TYPE
8992
self.type = type or AssetTypes.CUSTOM_MODEL
93+
self.stage = stage
9094
if self._is_anonymous and self.path:
9195
_ignore_file = get_ignore_file(self.path)
9296
_upload_hash = get_object_hash(self.path, _ignore_file)
@@ -115,6 +119,7 @@ def _to_dict(self) -> Dict:
115119
def _from_rest_object(cls, model_rest_object: ModelVersion) -> "Model":
116120
rest_model_version: ModelVersionProperties = model_rest_object.properties
117121
arm_id = AMLVersionedArmId(arm_id=model_rest_object.id)
122+
model_stage = rest_model_version.stage if hasattr(rest_model_version, "stage") else None
118123
if hasattr(rest_model_version, "flavors"):
119124
flavors = {key: flavor.data for key, flavor in rest_model_version.flavors.items()}
120125
model = Model(
@@ -126,6 +131,7 @@ def _from_rest_object(cls, model_rest_object: ModelVersion) -> "Model":
126131
tags=rest_model_version.tags,
127132
flavors=flavors,
128133
properties=rest_model_version.properties,
134+
stage=model_stage,
129135
# pylint: disable=protected-access
130136
creation_context=SystemData._from_rest_object(model_rest_object.system_data),
131137
type=rest_model_version.model_type,
@@ -162,6 +168,7 @@ def _to_rest_object(self) -> ModelVersion:
162168
else None, # flatten OrderedDict to dict
163169
model_type=self.type,
164170
model_uri=self.path,
171+
stage=self.stage,
165172
is_anonymous=self._is_anonymous,
166173
)
167174
model_version_resource = ModelVersion(properties=model_version)

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,6 @@ def _get(self, name: str, version: Optional[str] = None) -> ModelVersion: # nam
248248
name=name,
249249
version=version,
250250
workspace_name=self._workspace_name,
251-
api_version="2023-02-01-preview",
252251
**self._scope_kwargs,
253252
)
254253
)
@@ -418,6 +417,7 @@ def restore(
418417
def list(
419418
self,
420419
name: Optional[str] = None,
420+
stage: Optional[str] = None,
421421
*,
422422
list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
423423
) -> Iterable[Model]:
@@ -444,6 +444,7 @@ def list(
444444
workspace_name=self._workspace_name,
445445
cls=lambda objs: [Model._from_rest_object(obj) for obj in objs],
446446
list_view_type=list_view_type,
447+
stage=stage,
447448
**self._scope_kwargs,
448449
)
449450
)

sdk/ml/azure-ai-ml/tests/model/e2etests/test_model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,35 @@ def test_crud_file(self, client: MLClient, randstr: Callable[[], str], tmp_path:
6969
# with pytest.raises(Exception):
7070
# client.models.get(name=model.name, version="3")
7171

72+
def test_crud_model_with_stage(self, client: MLClient, randstr: Callable[[], str], tmp_path: Path) -> None:
73+
path = Path("./tests/test_configs/model/model_with_stage.yml")
74+
model_name = randstr("model_prod_name")
75+
76+
model = load_model(path)
77+
model.name = model_name
78+
model = client.models.create_or_update(model)
79+
assert model.name == model_name
80+
assert model.version == "3"
81+
assert model.description == "this is my test model with stage"
82+
assert model.type == "mlflow_model"
83+
assert model.stage == "Production"
84+
assert re.match(LONG_URI_REGEX_FORMAT, model.path)
85+
86+
with pytest.raises(Exception):
87+
with patch("azure.ai.ml._artifacts._artifact_utilities.get_object_hash", return_value="DIFFERENT_HASH"):
88+
model = load_model(source=artifact_path)
89+
model = client.models.create_or_update(model)
90+
91+
model = client.models.get(model.name, "3")
92+
assert model.name == model_name
93+
assert model.version == "3"
94+
assert model.description == "this is my test model with stage"
95+
assert model.stage == "Production"
96+
97+
model_list = client.models.list(name=model.name, stage="Production")
98+
model_stage_list = [m.stage for m in model_list if m is not None]
99+
assert model.stage in model_stage_list
100+
72101
def test_list_no_name(self, client: MLClient) -> None:
73102
models = client.models.list()
74103
assert isinstance(models, Iterator)

sdk/ml/azure-ai-ml/tests/model/unittests/test_model_schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def test_deserialize_no_version(self) -> None:
3232
assert model._auto_increment_version
3333
assert model.type == AssetTypes.CUSTOM_MODEL # assert the default model type
3434

35+
def test_deserialize_with_stage(self) -> None:
36+
path = Path("./tests/test_configs/model/model_with_stage.yml")
37+
model = load_model(path)
38+
assert model.stage == "Production"
39+
3540
def test_ipp_model(self) -> None:
3641
rest_ipp_model = {
3742
"id": "azureml://registries/fake_registry/models/fake_ipp_model/versions/611575",

0 commit comments

Comments
 (0)