Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions sdk/keyvault/azure-keyvault-keys/tests/_async_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,16 @@ async def get_attestation_token(attestation_uri):
def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
"""returns a test decorator for test parameterization"""
params = [
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
pytest.param(p[0], p[1], id=p[0] + ("_mhsm" if p[1] else "_vault"))
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
]
return params


def get_release_policy(attestation_uri, **kwargs):
release_policy_json = {
"anyOf": [
{
"anyOf": [
{
"claim": "sdk-test",
"equals": True
}
],
"authority": attestation_uri.rstrip("/") + "/"
}
],
"version": "1.0.0"
"anyOf": [{"anyOf": [{"claim": "sdk-test", "equals": True}], "authority": attestation_uri.rstrip("/") + "/"}],
"version": "1.0.0",
}
policy_string = json.dumps(release_policy_json).encode()
return KeyReleasePolicy(policy_string, **kwargs)
Expand All @@ -63,9 +53,9 @@ def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):


def is_public_cloud():
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))
return ".microsoftonline.com" in os.getenv("AZURE_AUTHORITY_HOST", "")



class AsyncKeysClientPreparer(AzureRecordedTestCase):
def __init__(self, *args, **kwargs):
vault_playback_url = "https://vaultname.vault.azure.net"
Expand All @@ -83,25 +73,25 @@ def __init__(self, *args, **kwargs):

def __call__(self, fn):
async def _preparer(test_class, api_version, is_hsm, **kwargs):

self._skip_if_not_configured(api_version, is_hsm)
if not self.is_logging_enabled:
kwargs.update({"logging_enable": False})
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)
async with client:
await fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)
await fn(
test_class, client, is_hsm=is_hsm, managed_hsm_url=self.managed_hsm_url, vault_url=self.vault_url
)

return _preparer



def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys.aio import KeyClient

credential = self.get_credential(KeyClient, is_async=True)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def _set_mgmt_settings_real_values(self):
Expand Down
4 changes: 4 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/tests/_keys_test_case.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os

import pytest
Expand Down
32 changes: 10 additions & 22 deletions sdk/keyvault/azure-keyvault-keys/tests/_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,16 @@ def get_attestation_token(attestation_uri):
def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
"""returns a test decorator for test parameterization"""
params = [
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
pytest.param(p[0], p[1], id=p[0] + ("_mhsm" if p[1] else "_vault"))
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
]
return params


def get_release_policy(attestation_uri, **kwargs):
release_policy_json = {
"anyOf": [
{
"anyOf": [
{
"claim": "sdk-test",
"equals": True
}
],
"authority": attestation_uri.rstrip("/") + "/"
}
],
"version": "1.0.0"
"anyOf": [{"anyOf": [{"claim": "sdk-test", "equals": True}], "authority": attestation_uri.rstrip("/") + "/"}],
"version": "1.0.0",
}
policy_string = json.dumps(release_policy_json).encode()
return KeyReleasePolicy(policy_string, **kwargs)
Expand All @@ -63,9 +53,9 @@ def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):


def is_public_cloud():
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))
return ".microsoftonline.com" in os.getenv("AZURE_AUTHORITY_HOST", "")



class KeysClientPreparer(AzureRecordedTestCase):
def __init__(self, *args, **kwargs):
vault_playback_url = "https://vaultname.vault.azure.net"
Expand All @@ -87,26 +77,24 @@ def __init__(self, *args, **kwargs):
def __call__(self, fn):
def _preparer(test_class, api_version, is_hsm, **kwargs):

#self._skip_if_not_configured(api_version, is_hsm)
self._skip_if_not_configured(api_version, is_hsm)
if not self.is_logging_enabled:
kwargs.update({"logging_enable": False})
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)

with client:
fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)
return _preparer

fn(test_class, client, is_hsm=is_hsm, managed_hsm_url=self.managed_hsm_url, vault_url=self.vault_url)

return _preparer

def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys import KeyClient

credential = self.get_credential(KeyClient)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def _set_mgmt_settings_real_values(self):
if self.is_live:
Expand Down