Skip to content

Commit d0bb48d

Browse files
committed
swap to use CRT implementation of STS Web Identity Provider
1 parent d1035e2 commit d0bb48d

19 files changed

+402
-450
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ if (LEGACY_BUILD)
7171
option(ENABLE_PROTOCOL_TESTS "Enable protocol tests" OFF)
7272
option(DISABLE_DNS_REQUIRED_TESTS "Disable unit tests that require DNS lookup to succeed, useful when using a http client that does not perform DNS lookup" OFF)
7373
option(AWS_APPSTORE_SAFE "Remove reference to private Apple APIs for AES GCM in Common Crypto. If set to OFF you application will get rejected from the apple app store." OFF)
74+
option(AWS_ENABLE_CORE_INTEGRATION_TEST "Enables the core integration tests to be built which contains dependencies on other clients for setup and tear down" OFF)
7475

7576

7677
set(AWS_USER_AGENT_CUSTOMIZATION "" CACHE STRING "User agent extension")

cmake/sdksCommon.cmake

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ list(APPEND HIGH_LEVEL_SDK_LIST "text-to-speech")
8787
set(SDK_TEST_PROJECT_LIST "")
8888
list(APPEND SDK_TEST_PROJECT_LIST "cloudfront:tests/aws-cpp-sdk-cloudfront-integration-tests")
8989
list(APPEND SDK_TEST_PROJECT_LIST "cognito-identity:tests/aws-cpp-sdk-cognitoidentity-integration-tests")
90-
list(APPEND SDK_TEST_PROJECT_LIST "core:tests/aws-cpp-sdk-core-tests")
90+
if (AWS_ENABLE_CORE_INTEGRATION_TEST)
91+
list(APPEND SDK_TEST_PROJECT_LIST "core:tests/aws-cpp-sdk-core-tests,tests/aws-cpp-sdk-core-integration-tests")
92+
else ()
93+
list(APPEND SDK_TEST_PROJECT_LIST "core:tests/aws-cpp-sdk-core-tests")
94+
endif ()
9195
list(APPEND SDK_TEST_PROJECT_LIST "dynamodb:tests/aws-cpp-sdk-dynamodb-integration-tests")
9296
list(APPEND SDK_TEST_PROJECT_LIST "dynamodb:tests/aws-cpp-sdk-dynamodb-unit-tests")
9397
list(APPEND SDK_TEST_PROJECT_LIST "ec2:tests/aws-cpp-sdk-ec2-integration-tests")
@@ -146,6 +150,9 @@ list(APPEND TEST_DEPENDENCY_LIST "sqs:access-management,cognito-identity,iam,cor
146150
list(APPEND TEST_DEPENDENCY_LIST "text-to-speech:polly,core")
147151
list(APPEND TEST_DEPENDENCY_LIST "transfer:s3,core")
148152
list(APPEND TEST_DEPENDENCY_LIST "logs:access-management,cognito-identity,iam,core")
153+
if (AWS_ENABLE_CORE_INTEGRATION_TEST)
154+
list(APPEND TEST_DEPENDENCY_LIST "core:sts,iam,cognito-identity")
155+
endif ()
149156

150157
set(GENERATED_SERVICE_LIST ${SERVICE_CLIENT_LIST})
151158
foreach(SERVICE_NAME IN LISTS SERVICE_CLIENT_LIST)

src/aws-cpp-sdk-core/include/aws/core/Globals.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#pragma once
77

88
#include <aws/core/Core_EXPORTS.h>
9+
#include <memory>
910

1011
namespace Aws
1112
{

src/aws-cpp-sdk-core/include/aws/core/auth/STSCredentialsProvider.h

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
#pragma once
88

99
#include <aws/core/Core_EXPORTS.h>
10-
#include <aws/core/utils/DateTime.h>
11-
#include <aws/core/utils/memory/stl/AWSString.h>
12-
#include <aws/core/internal/AWSHttpResourceClient.h>
1310
#include <aws/core/auth/AWSCredentialsProvider.h>
14-
#include <memory>
11+
12+
namespace Aws {
13+
namespace Crt {
14+
namespace Auth {
15+
class ICredentialsProvider;
16+
}
17+
}
18+
}
1519

1620
namespace Aws
1721
{
@@ -27,6 +31,7 @@ namespace Aws
2731
public:
2832
STSAssumeRoleWebIdentityCredentialsProvider();
2933
STSAssumeRoleWebIdentityCredentialsProvider(Aws::Client::ClientConfiguration::CredentialProviderConfiguration config);
34+
virtual ~STSAssumeRoleWebIdentityCredentialsProvider();
3035

3136
/**
3237
* Retrieves the credentials if found, otherwise returns empty credential set.
@@ -37,17 +42,14 @@ namespace Aws
3742
void Reload() override;
3843

3944
private:
40-
void RefreshIfExpired();
41-
Aws::String CalculateQueryString() const;
42-
43-
Aws::UniquePtr<Aws::Internal::STSCredentialsClient> m_client;
44-
Aws::Auth::AWSCredentials m_credentials;
45-
Aws::String m_roleArn;
46-
Aws::String m_tokenFile;
47-
Aws::String m_sessionName;
48-
Aws::String m_token;
49-
bool m_initialized;
50-
bool ExpiresSoon() const;
45+
enum class STATE {
46+
INITIALIZED,
47+
SHUT_DOWN,
48+
} m_state{STATE::SHUT_DOWN};
49+
std::mutex m_refreshMutex;
50+
std::condition_variable m_refreshSignal;
51+
std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> m_credentialsProvider;
52+
std::chrono::milliseconds m_providerFuturesTimeoutMs;
5153
};
5254
} // namespace Auth
5355
} // namespace Aws

src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,33 @@ namespace Aws
519519
bool disableImdsV1;
520520
bool disableImds;
521521
} imdsConfig;
522-
}credentialProviderConfig;
522+
523+
/**
524+
* Configuration for the STSCredentials provider
525+
*/
526+
struct STSCredentialsCredentialProviderConfiguration {
527+
STSCredentialsCredentialProviderConfiguration() = default;
528+
STSCredentialsCredentialProviderConfiguration(const Aws::String& role, const Aws::String& session, const String& tokenFile)
529+
: roleArn(role), sessionName(session), tokenFilePath(tokenFile) {};
530+
/**
531+
* Arn of the role to assume by fetching credentials for
532+
*/
533+
Aws::String roleArn;
534+
/**
535+
* Assumed role session identifier to be associated with the sourced credentials
536+
*/
537+
Aws::String sessionName;
538+
/**
539+
* The OAuth 2.0 access token or OpenID Connect ID token
540+
*/
541+
Aws::String tokenFilePath;
542+
543+
/**
544+
* Time out for the credentials future call.
545+
*/
546+
std::chrono::milliseconds retrieveCredentialsFutureTimeout = std::chrono::seconds(10);
547+
} stsCredentialsProviderConfig;
548+
} credentialProviderConfig;
523549
};
524550

525551
/**

src/aws-cpp-sdk-core/source/auth/STSCredentialsProvider.cpp

Lines changed: 85 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -2,166 +2,108 @@
22
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
33
* SPDX-License-Identifier: Apache-2.0.
44
*/
5-
6-
5+
#include <aws/core/Globals.h>
76
#include <aws/core/auth/STSCredentialsProvider.h>
8-
#include <aws/core/config/AWSProfileConfigLoader.h>
7+
#include <aws/core/client/ClientConfiguration.h>
98
#include <aws/core/platform/Environment.h>
10-
#include <aws/core/platform/FileSystem.h>
11-
#include <aws/core/utils/logging/LogMacros.h>
12-
#include <aws/core/utils/StringUtils.h>
13-
#include <aws/core/utils/FileSystemUtils.h>
14-
#include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h>
15-
#include <aws/core/utils/StringUtils.h>
16-
#include <aws/core/utils/UUID.h>
17-
#include <cstdlib>
18-
#include <fstream>
19-
#include <string.h>
20-
#include <climits>
21-
9+
#include <aws/crt/auth/Credentials.h>
2210

23-
using namespace Aws::Utils;
24-
using namespace Aws::Utils::Logging;
2511
using namespace Aws::Auth;
26-
using namespace Aws::Internal;
27-
using namespace Aws::FileSystem;
28-
using namespace Aws::Client;
29-
using Aws::Utils::Threading::ReaderLockGuard;
30-
using Aws::Utils::Threading::WriterLockGuard;
12+
using namespace Aws::Utils;
3113

32-
static const char STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG[] = "STSAssumeRoleWithWebIdentityCredentialsProvider";
33-
static const int STS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD = 5 * 60 * 1000; // 5 Minutes.
14+
namespace {
15+
const char* STS_LOG_TAG = "STSAssumeRoleWebIdentityCredentialsProvider";
16+
}
3417

35-
STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider(Aws::Client::ClientConfiguration::CredentialProviderConfiguration credentialsConfig):
36-
m_initialized(false)
18+
STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider(
19+
Aws::Client::ClientConfiguration::CredentialProviderConfiguration credentialsConfig)
20+
: m_credentialsProvider(nullptr), m_providerFuturesTimeoutMs(credentialsConfig.stsCredentialsProviderConfig.retrieveCredentialsFutureTimeout)
3721
{
38-
m_roleArn = Aws::Environment::GetEnv("AWS_ROLE_ARN");
39-
m_tokenFile = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE");
40-
m_sessionName = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME");
41-
42-
// check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable
43-
// region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file.
44-
if (m_roleArn.empty() || m_tokenFile.empty())
45-
{
46-
auto profile = Aws::Config::GetCachedConfigProfile(credentialsConfig.profile);
47-
// If either of these two were not found from environment, use whatever found for all three in config file
48-
if (m_roleArn.empty() || m_tokenFile.empty())
49-
{
50-
m_roleArn = profile.GetRoleArn();
51-
m_tokenFile = profile.GetValue("web_identity_token_file");
52-
m_sessionName = profile.GetValue("role_session_name");
53-
}
54-
}
55-
56-
if (m_tokenFile.empty())
57-
{
58-
AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Token file must be specified to use STS AssumeRole web identity creds provider.");
59-
return; // No need to do further constructing
22+
Aws::Crt::Auth::CredentialsProviderSTSWebIdentityConfig stsConfig{};
23+
stsConfig.Bootstrap = GetDefaultClientBootstrap();
24+
Aws::Crt::Io::TlsContextOptions tlsCtxOptions = Aws::Crt::Io::TlsContextOptions::InitDefaultClient();
25+
const Aws::Crt::Io::TlsContext tlsContext(tlsCtxOptions, Aws::Crt::Io::TlsMode::CLIENT);
26+
stsConfig.TlsCtx = tlsContext;
27+
stsConfig.Region = credentialsConfig.region.c_str();
28+
stsConfig.TokenFilePath = credentialsConfig.stsCredentialsProviderConfig.tokenFilePath.c_str();
29+
stsConfig.RoleArn = credentialsConfig.stsCredentialsProviderConfig.roleArn.c_str();
30+
stsConfig.SessionName = [&credentialsConfig]() -> Aws::String {
31+
if (!credentialsConfig.stsCredentialsProviderConfig.sessionName.empty()) {
32+
return credentialsConfig.stsCredentialsProviderConfig.sessionName;
6033
}
61-
else
62-
{
63-
AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved token_file from profile_config or environment variable to be " << m_tokenFile);
64-
}
65-
66-
if (m_roleArn.empty())
67-
{
68-
AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "RoleArn must be specified to use STS AssumeRole web identity creds provider.");
69-
return; // No need to do further constructing
70-
}
71-
else
72-
{
73-
AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved role_arn from profile_config or environment variable to be " << m_roleArn);
74-
}
75-
76-
if (m_sessionName.empty())
77-
{
78-
m_sessionName = Aws::Utils::UUID::PseudoRandomUUID();
79-
}
80-
else
81-
{
82-
AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved session_name from profile_config or environment variable to be " << m_sessionName);
83-
}
84-
85-
Aws::Client::ClientConfiguration config;
86-
config.scheme = Aws::Http::Scheme::HTTPS;
87-
config.region = credentialsConfig.region;
88-
Aws::Vector<Aws::String> retryableErrors;
89-
retryableErrors.push_back("IDPCommunicationError");
90-
retryableErrors.push_back("InvalidIdentityToken");
91-
92-
config.retryStrategy = Aws::MakeShared<SpecifiedRetryableErrorsRetryStrategy>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, retryableErrors, 3/*maxRetries*/);
93-
94-
m_client = Aws::MakeUnique<Aws::Internal::STSCredentialsClient>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, config);
95-
m_initialized = true;
96-
AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Creating STS AssumeRole with web identity creds provider.");
34+
return UUID::RandomUUID();
35+
}().c_str();
36+
m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity(stsConfig);
37+
if (m_credentialsProvider && m_credentialsProvider->IsValid()) {
38+
m_state = STATE::INITIALIZED;
39+
} else {
40+
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create STS credentials provider");
41+
}
9742
}
9843

99-
Aws::String LegacyGetRegion() {
100-
auto region = Aws::Environment::GetEnv("AWS_DEFAULT_REGION");
101-
if (region.empty()) {
44+
Aws::String GetLegacySettingFromEnvOrProfile(const Aws::String& envVar,
45+
std::function<Aws::String (Aws::Config::Profile)> profileFetchFunction)
46+
{
47+
auto value = Aws::Environment::GetEnv(envVar.c_str());
48+
if (value.empty()) {
10249
auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName());
103-
region = profile.GetRegion();
50+
value = profileFetchFunction(profile);
10451
}
105-
return region;
52+
return value;
10653
}
10754

10855
STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider()
10956
: STSAssumeRoleWebIdentityCredentialsProvider(
110-
Aws::Client::ClientConfiguration::CredentialProviderConfiguration{Aws::Auth::GetConfigProfileName(), LegacyGetRegion(), {}}) {}
111-
112-
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
113-
{
114-
// A valid client means required information like role arn and token file were constructed correctly.
115-
// We can use this provider to load creds, otherwise, we can just return empty creds.
116-
if (!m_initialized)
117-
{
118-
return Aws::Auth::AWSCredentials();
119-
}
120-
RefreshIfExpired();
121-
ReaderLockGuard guard(m_reloadLock);
122-
return m_credentials;
123-
}
124-
125-
void STSAssumeRoleWebIdentityCredentialsProvider::Reload()
126-
{
127-
AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Credentials have expired, attempting to renew from STS.");
128-
129-
Aws::IFStream tokenFile(m_tokenFile.c_str());
130-
if(tokenFile)
131-
{
132-
Aws::String token((std::istreambuf_iterator<char>(tokenFile)), std::istreambuf_iterator<char>());
133-
m_token = token;
134-
}
135-
else
136-
{
137-
AWS_LOGSTREAM_ERROR(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Can't open token file: " << m_tokenFile);
138-
return;
139-
}
140-
STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request {m_sessionName, m_roleArn, m_token};
141-
142-
auto result = m_client->GetAssumeRoleWithWebIdentityCredentials(request);
143-
AWS_LOGSTREAM_TRACE(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Successfully retrieved credentials with AWS_ACCESS_KEY: " << result.creds.GetAWSAccessKeyId());
144-
m_credentials = result.creds;
145-
}
57+
Aws::Client::ClientConfiguration::CredentialProviderConfiguration{
58+
Aws::Auth::GetConfigProfileName(),
59+
GetLegacySettingFromEnvOrProfile("AWS_DEFAULT_REGION",
60+
[](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetRegion(); }),
61+
{},
62+
{
63+
GetLegacySettingFromEnvOrProfile("AWS_ROLE_ARN",
64+
[](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetRoleArn(); }),
65+
GetLegacySettingFromEnvOrProfile("AWS_ROLE_SESSION_NAME",
66+
[](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetValue("role_session_name"); }),
67+
GetLegacySettingFromEnvOrProfile("AWS_WEB_IDENTITY_TOKEN_FILE",
68+
[](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetValue("web_identity_token_file"); })
69+
}})
70+
{}
71+
72+
STSAssumeRoleWebIdentityCredentialsProvider::~STSAssumeRoleWebIdentityCredentialsProvider() = default;
73+
74+
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials() {
75+
if (m_state != STATE::INITIALIZED) {
76+
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "STSCredentialsProvider is not initialized, returning empty credentials");
77+
return AWSCredentials{};
78+
}
79+
AWSCredentials credentials{};
80+
auto refreshDone = false;
81+
m_credentialsProvider->GetCredentials(
82+
[this, &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
83+
{
84+
const std::unique_lock<std::mutex> lock{m_refreshMutex};
85+
if (errorCode != AWS_ERROR_SUCCESS) {
86+
AWS_LOGSTREAM_ERROR(STS_LOG_TAG, "Failed to get credentials from STS: " << errorCode);
87+
} else {
88+
const auto accountIdCursor = crtCredentials->GetAccessKeyId();
89+
credentials.SetAWSAccessKeyId({reinterpret_cast<char*>(accountIdCursor.ptr), accountIdCursor.len});
90+
const auto secretKeuCursor = crtCredentials->GetSecretAccessKey();
91+
credentials.SetAWSSecretKey({reinterpret_cast<char*>(secretKeuCursor.ptr), secretKeuCursor.len});
92+
const auto expiration = crtCredentials->GetExpirationTimepointInSeconds();
93+
credentials.SetExpiration(DateTime{static_cast<double>(expiration)});
94+
const auto sessionTokenCursor = crtCredentials->GetSessionToken();
95+
credentials.SetSessionToken({reinterpret_cast<char*>(sessionTokenCursor.ptr), sessionTokenCursor.len});
96+
}
97+
refreshDone = true;
98+
}
99+
m_refreshSignal.notify_one();
100+
});
146101

147-
bool STSAssumeRoleWebIdentityCredentialsProvider::ExpiresSoon() const
148-
{
149-
return ((m_credentials.GetExpiration() - Aws::Utils::DateTime::Now()).count() < STS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD);
102+
std::unique_lock<std::mutex> lock{m_refreshMutex};
103+
m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone; });
104+
return credentials;
150105
}
151106

152-
void STSAssumeRoleWebIdentityCredentialsProvider::RefreshIfExpired()
153-
{
154-
ReaderLockGuard guard(m_reloadLock);
155-
if (!m_credentials.IsEmpty() && !ExpiresSoon())
156-
{
157-
return;
158-
}
159-
160-
guard.UpgradeToWriterLock();
161-
if (!m_credentials.IsExpiredOrEmpty() && !ExpiresSoon()) // double-checked lock to avoid refreshing twice
162-
{
163-
return;
164-
}
165-
166-
Reload();
107+
void STSAssumeRoleWebIdentityCredentialsProvider::Reload() {
108+
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Calling reload on STSCredentialsProvider is a no-op and no longer in the call path");
167109
}

0 commit comments

Comments
 (0)