Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 0 additions & 8 deletions sdk/keyvault/azure-keyvault-keys/conftest.py

This file was deleted.

134 changes: 134 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/tests/_async_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import os

import pytest
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.transport import AioHttpTransport, HttpRequest
from azure.keyvault.keys import KeyReleasePolicy
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
from devtools_testutils import AzureRecordedTestCase


async def get_attestation_token(attestation_uri):
request = HttpRequest("GET", "{}/generate-test-token".format(attestation_uri))
async with AsyncPipeline(transport=AioHttpTransport()) as pipeline:
response = await pipeline.run(request)
return json.loads(response.http_response.text())["token"]


def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
"""returns a test decorator for test parameterization"""
params = [
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
]
return params


def get_release_policy(attestation_uri, **kwargs):
release_policy_json = {
"anyOf": [
{
"anyOf": [
{
"claim": "sdk-test",
"equals": True
}
],
"authority": attestation_uri.rstrip("/") + "/"
}
],
"version": "1.0.0"
}
policy_string = json.dumps(release_policy_json).encode()
return KeyReleasePolicy(policy_string, **kwargs)


def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
"""generates a list of parameter pairs for test case parameterization, where [x, y] = [api_version, is_hsm]"""
combinations = []
versions = api_versions or ApiVersion
hsm_supported_versions = {ApiVersion.V7_2, ApiVersion.V7_3}

for api_version in versions:
if not only_vault and api_version in hsm_supported_versions:
combinations.append([api_version, True])
if not only_hsm:
combinations.append([api_version, False])
return combinations


def is_public_cloud():
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))


class AsyncKeysClientPreparer(AzureRecordedTestCase):
def __init__(self, *args, **kwargs):
vault_playback_url = "https://vaultname.vault.azure.net"
hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
self.is_logging_enabled = kwargs.pop("logging_enable", True)

if self.is_live:
self.vault_url = os.environ["AZURE_KEYVAULT_URL"]
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL")
else:
self.vault_url = vault_playback_url
self.managed_hsm_url = hsm_playback_url

self._set_mgmt_settings_real_values()

def __call__(self, fn):
async def _preparer(test_class, api_version, is_hsm, **kwargs):

self._skip_if_not_configured(api_version, is_hsm)
if not self.is_logging_enabled:
kwargs.update({"logging_enable": False})
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)
async with client:
await fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)

return _preparer



def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys.aio import KeyClient

credential = self.get_credential(KeyClient, is_async=True)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def create_crypto_client(self, key, **kwargs):

from azure.keyvault.keys.crypto.aio import CryptographyClient

credential = self.get_credential(CryptographyClient, is_async=True)
return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)

Comment thread
kashifkhan marked this conversation as resolved.
Outdated
def _get_attestation_uri(self):
playback_uri = "https://fakeattestation.azurewebsites.net"
if self.is_live:
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
if real_uri is None:
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
self._scrub_url(real_uri, playback_uri)
return real_uri
return playback_uri
Comment thread
kashifkhan marked this conversation as resolved.
Outdated

def _set_mgmt_settings_real_values(self):
if self.is_live:
os.environ["AZURE_TENANT_ID"] = os.environ["KEYVAULT_TENANT_ID"]
os.environ["AZURE_CLIENT_ID"] = os.environ["KEYVAULT_CLIENT_ID"]
os.environ["AZURE_CLIENT_SECRET"] = os.environ["KEYVAULT_CLIENT_SECRET"]

def _skip_if_not_configured(self, api_version, is_hsm):
if self.is_live and api_version != DEFAULT_VERSION:
pytest.skip("This test only uses the default API version for live tests")
if self.is_live and is_hsm and self.managed_hsm_url is None:
pytest.skip("No HSM endpoint for live testing")
46 changes: 34 additions & 12 deletions sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import pytest
Comment thread
kashifkhan marked this conversation as resolved.
Outdated
import time

from azure_devtools.scenario_tests.patches import patch_time_sleep_api
from devtools_testutils import AzureTestCase
from azure.keyvault.keys._shared import HttpChallengeCache
from devtools_testutils import AzureRecordedTestCase


class KeyVaultTestCase(AzureTestCase):
def __init__(self, *args, **kwargs):
if "match_body" not in kwargs:
kwargs["match_body"] = True

super(KeyVaultTestCase, self).__init__(*args, **kwargs)
self.replay_patches.append(patch_time_sleep_api)

def setUp(self):
self.list_test_size = 7
super(KeyVaultTestCase, self).setUp()

class KeyVaultTestCase(AzureRecordedTestCase):
def get_resource_name(self, name):
"""helper to create resources with a consistent, test-indicative prefix"""
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))

def _get_attestation_uri(self):
Comment thread
kashifkhan marked this conversation as resolved.
Outdated
playback_uri = "https://fakeattestation.azurewebsites.net"
if self.is_live:
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
real_uri = real_uri.rstrip('/')
if real_uri is None:
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
return real_uri
return playback_uri

def create_crypto_client(self, key, **kwargs):

from azure.keyvault.keys.crypto import CryptographyClient

credential = self.get_credential(CryptographyClient)
return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)

def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys import KeyClient

credential = self.get_credential(KeyClient)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def _poll_until_no_exception(self, fn, expected_exception, max_retries=20, retry_delay=10):
"""polling helper for live tests because some operations take an unpredictable amount of time to complete"""

Expand All @@ -48,3 +66,7 @@ def _poll_until_exception(self, fn, expected_exception, max_retries=20, retry_de
return

self.fail("expected exception {expected_exception} was not raised")

def teardown_method(self, method):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
45 changes: 27 additions & 18 deletions sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,36 @@
# Licensed under the MIT License.
# ------------------------------------
import asyncio
import pytest
import os
Comment thread
kashifkhan marked this conversation as resolved.
Outdated

from azure_devtools.scenario_tests.patches import mock_in_unit_test
from devtools_testutils import AzureTestCase
from devtools_testutils import AzureRecordedTestCase
from azure.keyvault.keys._shared import HttpChallengeCache


def skip_sleep(unit_test):
async def immediate_return(_):
return

return mock_in_unit_test(unit_test, "asyncio.sleep", immediate_return)


class KeyVaultTestCase(AzureTestCase):
def __init__(self, *args, match_body=True, **kwargs):
super().__init__(*args, match_body=match_body, **kwargs)
self.replay_patches.append(skip_sleep)

def setUp(self):
self.list_test_size = 7
super(KeyVaultTestCase, self).setUp()

class KeyVaultTestCase(AzureRecordedTestCase):
def get_resource_name(self, name):
"""helper to create resources with a consistent, test-indicative prefix"""
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))

def _get_attestation_uri(self):
playback_uri = "https://fakeattestation.azurewebsites.net"
if self.is_live:
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
real_uri = real_uri.rstrip('/')
if real_uri is None:
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
#self._scrub_url(real_uri, playback_uri)
Comment thread
kashifkhan marked this conversation as resolved.
Outdated
return real_uri
return playback_uri

def create_crypto_client(self, key, **kwargs):

from azure.keyvault.keys.crypto.aio import CryptographyClient

credential = self.get_credential(CryptographyClient, is_async=True)
return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)

async def _poll_until_no_exception(self, fn, expected_exception, max_retries=20, retry_delay=10):
"""polling helper for live tests because some operations take an unpredictable amount of time to complete"""

Expand All @@ -51,3 +56,7 @@ async def _poll_until_exception(self, fn, expected_exception, max_retries=20, re
except expected_exception:
return
self.fail("expected exception {expected_exception} was not raised")

def teardown_method(self, method):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
Loading