|
| 1 | +// Copyright (c) Microsoft Corporation. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "azure/identity/client_assertion_credential.hpp" |
| 5 | + |
| 6 | +#include "private/identity_log.hpp" |
| 7 | +#include "private/package_version.hpp" |
| 8 | +#include "private/tenant_id_resolver.hpp" |
| 9 | +#include "private/token_credential_impl.hpp" |
| 10 | + |
| 11 | +#include <azure/core/internal/json/json.hpp> |
| 12 | + |
| 13 | +using Azure::Identity::ClientAssertionCredential; |
| 14 | +using Azure::Identity::ClientAssertionCredentialOptions; |
| 15 | + |
| 16 | +using Azure::Core::Context; |
| 17 | +using Azure::Core::Url; |
| 18 | +using Azure::Core::_internal::StringExtensions; |
| 19 | +using Azure::Core::Credentials::AccessToken; |
| 20 | +using Azure::Core::Credentials::AuthenticationException; |
| 21 | +using Azure::Core::Credentials::TokenRequestContext; |
| 22 | +using Azure::Core::Http::HttpMethod; |
| 23 | +using Azure::Identity::_detail::IdentityLog; |
| 24 | +using Azure::Identity::_detail::TenantIdResolver; |
| 25 | +using Azure::Identity::_detail::TokenCredentialImpl; |
| 26 | + |
| 27 | +namespace { |
| 28 | +bool IsValidTenantId(std::string const& tenantId) |
| 29 | +{ |
| 30 | + const std::string allowedChars = ".-"; |
| 31 | + if (tenantId.empty()) |
| 32 | + { |
| 33 | + return false; |
| 34 | + } |
| 35 | + for (auto const c : tenantId) |
| 36 | + { |
| 37 | + if (allowedChars.find(c) != std::string::npos) |
| 38 | + { |
| 39 | + continue; |
| 40 | + } |
| 41 | + if (!StringExtensions::IsAlphaNumeric(c)) |
| 42 | + { |
| 43 | + return false; |
| 44 | + } |
| 45 | + } |
| 46 | + return true; |
| 47 | +} |
| 48 | +} // namespace |
| 49 | + |
| 50 | +ClientAssertionCredential::ClientAssertionCredential( |
| 51 | + std::string tenantId, |
| 52 | + std::string clientId, |
| 53 | + std::function<std::string(Context const&)> assertionCallback, |
| 54 | + ClientAssertionCredentialOptions const& options) |
| 55 | + : TokenCredential("ClientAssertionCredential"), |
| 56 | + m_assertionCallback(std::move(assertionCallback)), |
| 57 | + m_clientCredentialCore(tenantId, options.AuthorityHost, options.AdditionallyAllowedTenants) |
| 58 | +{ |
| 59 | + bool isTenantIdValid = IsValidTenantId(tenantId); |
| 60 | + if (!isTenantIdValid) |
| 61 | + { |
| 62 | + IdentityLog::Write( |
| 63 | + IdentityLog::Level::Warning, |
| 64 | + GetCredentialName() |
| 65 | + + ": Invalid tenant ID provided. The tenant ID must be a non-empty string containing " |
| 66 | + "only alphanumeric characters, periods, or hyphens. You can locate your tenant ID by " |
| 67 | + "following the instructions listed here: " |
| 68 | + "https://learn.microsoft.com/partner-center/find-ids-and-domain-names"); |
| 69 | + } |
| 70 | + if (clientId.empty()) |
| 71 | + { |
| 72 | + IdentityLog::Write( |
| 73 | + IdentityLog::Level::Warning, GetCredentialName() + ": No client ID specified."); |
| 74 | + } |
| 75 | + if (!m_assertionCallback) |
| 76 | + { |
| 77 | + IdentityLog::Write( |
| 78 | + IdentityLog::Level::Warning, |
| 79 | + GetCredentialName() |
| 80 | + + ": The assertionCallback must be a valid function that returns assertions."); |
| 81 | + } |
| 82 | + |
| 83 | + if (isTenantIdValid && !clientId.empty() && m_assertionCallback) |
| 84 | + { |
| 85 | + m_tokenCredentialImpl = std::make_unique<TokenCredentialImpl>(options); |
| 86 | + m_requestBody |
| 87 | + = std::string( |
| 88 | + "grant_type=client_credentials" |
| 89 | + "&client_assertion_type=" |
| 90 | + "urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" // cspell:disable-line |
| 91 | + "&client_id=") |
| 92 | + + Url::Encode(clientId); |
| 93 | + |
| 94 | + IdentityLog::Write( |
| 95 | + IdentityLog::Level::Informational, GetCredentialName() + " was created successfully."); |
| 96 | + } |
| 97 | + else |
| 98 | + { |
| 99 | + // Rather than throwing an exception in the ctor, following the pattern in existing credentials |
| 100 | + // to log the errors, and defer throwing an exception to the first call of GetToken(). This is |
| 101 | + // primarily needed for credentials that are part of the DefaultAzureCredential, which this |
| 102 | + // credential is not intended for. |
| 103 | + IdentityLog::Write( |
| 104 | + IdentityLog::Level::Warning, GetCredentialName() + " was not initialized correctly."); |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +ClientAssertionCredential::~ClientAssertionCredential() = default; |
| 109 | + |
| 110 | +AccessToken ClientAssertionCredential::GetToken( |
| 111 | + TokenRequestContext const& tokenRequestContext, |
| 112 | + Context const& context) const |
| 113 | +{ |
| 114 | + if (!m_tokenCredentialImpl) |
| 115 | + { |
| 116 | + auto const AuthUnavailable = GetCredentialName() + " authentication unavailable. "; |
| 117 | + |
| 118 | + IdentityLog::Write( |
| 119 | + IdentityLog::Level::Warning, |
| 120 | + AuthUnavailable + "See earlier " + GetCredentialName() + " log messages for details."); |
| 121 | + |
| 122 | + throw AuthenticationException(AuthUnavailable); |
| 123 | + } |
| 124 | + |
| 125 | + auto const tenantId = TenantIdResolver::Resolve( |
| 126 | + m_clientCredentialCore.GetTenantId(), |
| 127 | + tokenRequestContext, |
| 128 | + m_clientCredentialCore.GetAdditionallyAllowedTenants()); |
| 129 | + |
| 130 | + auto const scopesStr |
| 131 | + = m_clientCredentialCore.GetScopesString(tenantId, tokenRequestContext.Scopes); |
| 132 | + |
| 133 | + // TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda |
| 134 | + // argument when they are being executed. They are not supposed to keep a reference to lambda |
| 135 | + // argument to call it later. Therefore, any capture made here will outlive the possible time |
| 136 | + // frame when the lambda might get called. |
| 137 | + return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() { |
| 138 | + return m_tokenCredentialImpl->GetToken(context, false, [&]() { |
| 139 | + auto body = m_requestBody; |
| 140 | + if (!scopesStr.empty()) |
| 141 | + { |
| 142 | + body += "&scope=" + scopesStr; |
| 143 | + } |
| 144 | + |
| 145 | + // Get the request url before calling m_assertionCallback to validate the authority host |
| 146 | + // scheme (GetRequestUrl() will throw if validation fails). This is to avoid calling the |
| 147 | + // assertion callback if the authority host scheme is invalid. |
| 148 | + auto const requestUrl = m_clientCredentialCore.GetRequestUrl(tenantId); |
| 149 | + |
| 150 | + const std::string assertion = m_assertionCallback(context); |
| 151 | + |
| 152 | + body += "&client_assertion=" + Azure::Core::Url::Encode(assertion); |
| 153 | + |
| 154 | + auto request |
| 155 | + = std::make_unique<TokenCredentialImpl::TokenRequest>(HttpMethod::Post, requestUrl, body); |
| 156 | + |
| 157 | + request->HttpRequest.SetHeader("Host", requestUrl.GetHost()); |
| 158 | + |
| 159 | + return request; |
| 160 | + }); |
| 161 | + }); |
| 162 | +} |
0 commit comments