Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 0 additions & 8 deletions sdk/keyvault/azure-keyvault-keys/conftest.py

This file was deleted.

117 changes: 117 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/tests/_async_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import os

import pytest
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.transport import AioHttpTransport, HttpRequest
from azure.keyvault.keys import KeyReleasePolicy
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
from devtools_testutils import AzureRecordedTestCase


async def get_attestation_token(attestation_uri):
request = HttpRequest("GET", "{}/generate-test-token".format(attestation_uri))
async with AsyncPipeline(transport=AioHttpTransport()) as pipeline:
response = await pipeline.run(request)
return json.loads(response.http_response.text())["token"]


def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
"""returns a test decorator for test parameterization"""
params = [
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
]
return params


def get_release_policy(attestation_uri, **kwargs):
release_policy_json = {
"anyOf": [
{
"anyOf": [
{
"claim": "sdk-test",
"equals": True
}
],
"authority": attestation_uri.rstrip("/") + "/"
}
],
"version": "1.0.0"
}
policy_string = json.dumps(release_policy_json).encode()
return KeyReleasePolicy(policy_string, **kwargs)


def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
"""generates a list of parameter pairs for test case parameterization, where [x, y] = [api_version, is_hsm]"""
combinations = []
versions = api_versions or ApiVersion
hsm_supported_versions = {ApiVersion.V7_2, ApiVersion.V7_3}

for api_version in versions:
if not only_vault and api_version in hsm_supported_versions:
combinations.append([api_version, True])
if not only_hsm:
combinations.append([api_version, False])
return combinations


def is_public_cloud():
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))


class AsyncKeysClientPreparer(AzureRecordedTestCase):
def __init__(self, *args, **kwargs):
vault_playback_url = "https://vaultname.vault.azure.net"
hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
self.is_logging_enabled = kwargs.pop("logging_enable", True)

if self.is_live:
self.vault_url = os.environ["AZURE_KEYVAULT_URL"]
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL")
else:
self.vault_url = vault_playback_url
self.managed_hsm_url = hsm_playback_url

self._set_mgmt_settings_real_values()

def __call__(self, fn):
async def _preparer(test_class, api_version, is_hsm, **kwargs):

self._skip_if_not_configured(api_version, is_hsm)
if not self.is_logging_enabled:
kwargs.update({"logging_enable": False})
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)
async with client:
await fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)

return _preparer



def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys.aio import KeyClient

credential = self.get_credential(KeyClient, is_async=True)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def _set_mgmt_settings_real_values(self):
if self.is_live:
os.environ["AZURE_TENANT_ID"] = os.environ["KEYVAULT_TENANT_ID"]
os.environ["AZURE_CLIENT_ID"] = os.environ["KEYVAULT_CLIENT_ID"]
os.environ["AZURE_CLIENT_SECRET"] = os.environ["KEYVAULT_CLIENT_SECRET"]

def _skip_if_not_configured(self, api_version, is_hsm):
if self.is_live and api_version != DEFAULT_VERSION:
pytest.skip("This test only uses the default API version for live tests")
if self.is_live and is_hsm and self.managed_hsm_url is None:
pytest.skip("No HSM endpoint for live testing")
26 changes: 26 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/tests/_keys_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os

import pytest
from devtools_testutils import AzureRecordedTestCase


class KeysTestCase(AzureRecordedTestCase):
def _get_attestation_uri(self):
playback_uri = "https://fakeattestation.azurewebsites.net"
if self.is_live:
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
real_uri = real_uri.rstrip('/')
if real_uri is None:
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
return real_uri
return playback_uri

def create_crypto_client(self, key, **kwargs):
if kwargs.pop("is_async", False):
from azure.keyvault.keys.crypto.aio import CryptographyClient
credential = self.get_credential(CryptographyClient,is_async=True)
else:
from azure.keyvault.keys.crypto import CryptographyClient
credential = self.get_credential(CryptographyClient)

return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)
19 changes: 7 additions & 12 deletions sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,13 @@
# ------------------------------------
import time

from azure_devtools.scenario_tests.patches import patch_time_sleep_api
from devtools_testutils import AzureTestCase
from azure.keyvault.keys._shared import HttpChallengeCache
from devtools_testutils import AzureRecordedTestCase


class KeyVaultTestCase(AzureTestCase):
def __init__(self, *args, **kwargs):
if "match_body" not in kwargs:
kwargs["match_body"] = True

super(KeyVaultTestCase, self).__init__(*args, **kwargs)
self.replay_patches.append(patch_time_sleep_api)

def setUp(self):
self.list_test_size = 7
super(KeyVaultTestCase, self).setUp()

class KeyVaultTestCase(AzureRecordedTestCase):
def get_resource_name(self, name):
"""helper to create resources with a consistent, test-indicative prefix"""
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))
Expand Down Expand Up @@ -48,3 +39,7 @@ def _poll_until_exception(self, fn, expected_exception, max_retries=20, retry_de
return

self.fail("expected exception {expected_exception} was not raised")

def teardown_method(self, method):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,11 @@
# ------------------------------------
import asyncio

from azure_devtools.scenario_tests.patches import mock_in_unit_test
from devtools_testutils import AzureTestCase
from devtools_testutils import AzureRecordedTestCase
from azure.keyvault.keys._shared import HttpChallengeCache


def skip_sleep(unit_test):
async def immediate_return(_):
return

return mock_in_unit_test(unit_test, "asyncio.sleep", immediate_return)


class KeyVaultTestCase(AzureTestCase):
def __init__(self, *args, match_body=True, **kwargs):
super().__init__(*args, match_body=match_body, **kwargs)
self.replay_patches.append(skip_sleep)

def setUp(self):
self.list_test_size = 7
super(KeyVaultTestCase, self).setUp()

class KeyVaultTestCase(AzureRecordedTestCase):
def get_resource_name(self, name):
"""helper to create resources with a consistent, test-indicative prefix"""
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))
Expand Down Expand Up @@ -51,3 +36,7 @@ async def _poll_until_exception(self, fn, expected_exception, max_retries=20, re
except expected_exception:
return
self.fail("expected exception {expected_exception} was not raised")

def teardown_method(self, method):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
110 changes: 29 additions & 81 deletions sdk/keyvault/azure-keyvault-keys/tests/_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import functools
import json
import os

import pytest
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpRequest, RequestsTransport
from azure.keyvault.keys import KeyReleasePolicy
from azure.keyvault.keys._shared import HttpChallengeCache
from azure.keyvault.keys._shared.client_base import ApiVersion, DEFAULT_VERSION
from devtools_testutils import AzureTestCase
from parameterized import parameterized, param
import pytest
from six.moves.urllib_parse import urlparse


def client_setup(testcase_func):
"""decorator that creates a client to be passed in to a test method"""

@functools.wraps(testcase_func)
def wrapper(test_class_instance, api_version, is_hsm=False, **kwargs):
test_class_instance._skip_if_not_configured(api_version, is_hsm)
endpoint_url = test_class_instance.managed_hsm_url if is_hsm else test_class_instance.vault_url
client = test_class_instance.create_key_client(endpoint_url, api_version=api_version, **kwargs)

if kwargs.get("is_async"):
import asyncio

coroutine = testcase_func(test_class_instance, client, is_hsm=is_hsm)
loop = asyncio.get_event_loop()
loop.run_until_complete(coroutine)
else:
testcase_func(test_class_instance, client, is_hsm=is_hsm)

return wrapper
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
from devtools_testutils import AzureRecordedTestCase


def get_attestation_token(attestation_uri):
Expand All @@ -48,10 +23,10 @@ def get_attestation_token(attestation_uri):
def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
"""returns a test decorator for test parameterization"""
params = [
param(api_version=p[0], is_hsm=p[1], **kwargs)
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
]
return functools.partial(parameterized.expand, params, name_func=suffixed_test_name)
return params


def get_release_policy(attestation_uri, **kwargs):
Expand Down Expand Up @@ -87,78 +62,51 @@ def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
return combinations


def suffixed_test_name(testcase_func, param_num, param):
api_version = param.kwargs.get("api_version")
suffix = "mhsm" if param.kwargs.get("is_hsm") else "vault"
return "{}_{}_{}".format(
testcase_func.__name__, parameterized.to_safe_name(api_version), parameterized.to_safe_name(suffix)
)


def is_public_cloud():
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))


class KeysTestCase(AzureTestCase):
def setUp(self, *args, **kwargs):
class KeysClientPreparer(AzureRecordedTestCase):
def __init__(self, *args, **kwargs):
vault_playback_url = "https://vaultname.vault.azure.net"
hsm_playback_url = "https://managedhsmname.managedhsm.azure.net"
hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
self.is_logging_enabled = kwargs.pop("logging_enable", True)

if self.is_live:
self.vault_url = os.environ["AZURE_KEYVAULT_URL"]
self._scrub_url(real_url=self.vault_url, playback_url=vault_playback_url)

self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL")
self.vault_url = self.vault_url.rstrip("/")
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL", None)
if self.managed_hsm_url:
self._scrub_url(real_url=self.managed_hsm_url, playback_url=hsm_playback_url)
self.managed_hsm_url = self.managed_hsm_url.rstrip("/")
else:
self.vault_url = vault_playback_url
self.managed_hsm_url = hsm_playback_url

self._set_mgmt_settings_real_values()
super(KeysTestCase, self).setUp(*args, **kwargs)

def tearDown(self):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
super(KeysTestCase, self).tearDown()
def __call__(self, fn):
def _preparer(test_class, api_version, is_hsm, **kwargs):

def create_key_client(self, vault_uri, **kwargs):
if kwargs.pop("is_async", False):
from azure.keyvault.keys.aio import KeyClient

credential = self.get_credential(KeyClient, is_async=True)
else:
from azure.keyvault.keys import KeyClient
#self._skip_if_not_configured(api_version, is_hsm)
if not self.is_logging_enabled:
kwargs.update({"logging_enable": False})
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)

credential = self.get_credential(KeyClient)
return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)
with client:
fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)
return _preparer


def create_crypto_client(self, key, **kwargs):
if kwargs.pop("is_async", False):
from azure.keyvault.keys.crypto.aio import CryptographyClient

credential = self.get_credential(CryptographyClient, is_async=True)
else:
from azure.keyvault.keys.crypto import CryptographyClient
def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys import KeyClient

credential = self.get_credential(CryptographyClient)
return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)
credential = self.get_credential(KeyClient)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def _get_attestation_uri(self):
playback_uri = "https://fakeattestation.azurewebsites.net"
if self.is_live:
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
if real_uri is None:
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
self._scrub_url(real_uri, playback_uri)
return real_uri
return playback_uri

def _scrub_url(self, real_url, playback_url):
real = urlparse(real_url)
playback = urlparse(playback_url)
self.scrubber.register_name_pair(real.netloc, playback.netloc)

def _set_mgmt_settings_real_values(self):
if self.is_live:
Expand Down
Loading