Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sts to construct stream arn #463

Merged
merged 3 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ set(STATIC_LIBS
boost_chrono)

find_package(Threads)
find_package(AWSSDK REQUIRED COMPONENTS kinesis monitoring)
find_package(AWSSDK REQUIRED COMPONENTS kinesis monitoring sts)

add_library(LibCrypto STATIC IMPORTED)
set_property(TARGET LibCrypto PROPERTY IMPORTED_LOCATION ${THIRD_PARTY_LIB_DIR}/libcrypto.a)
Expand Down
62 changes: 62 additions & 0 deletions aws/kinesis/core/configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,26 @@ class Configuration : private boost::noncopyable {
return proxy_password_;
}

// Use a custom Sts endpoint.
//
// Note this does not accept protocols or paths, only host names or ip
// addresses. There is no way to disable TLS. The KPL always connects with
// TLS.
//
// Expected pattern: ^([A-Za-z0-9-\\.]+)?$
const std::string& sts_endpoint() const noexcept {
return sts_endpoint_;
}

// Server port to connect to for STS.
//
// Default: 443
// Minimum: 1
// Maximum (inclusive): 65535
size_t sts_port() const noexcept {
return sts_port_;
}

/// Indicates whether the SDK clients should use a thread pool or not
/// \return true if the client should use a thread pool, false otherwise
bool use_thread_pool() const noexcept {
Expand Down Expand Up @@ -1009,6 +1029,43 @@ class Configuration : private boost::noncopyable {
return *this;
}

// Use a custom STS endpoint.
//
// Note this does not accept protocols or paths, only host names or ip
// addresses. There is no way to disable TLS. The KPL always connects with
// TLS.
//
// Expected pattern: ^([A-Za-z0-9-\\.]+)?$
Configuration& sts_endpoint(std::string val) {
static std::regex pattern(
"^([A-Za-z0-9-\\.]+)?$",
std::regex::ECMAScript | std::regex::optimize);
if (!std::regex_match(val, pattern)) {
std::string err;
err += "sts_endpoint must match the pattern ^([A-Za-z0-9-\\.]+)?$, got ";
err += val;
throw std::runtime_error(err);
}
sts_endpoint_ = val;
return *this;
}

// Server port to connect to for STS.
//
// Default: 443
// Minimum: 1
// Maximum (inclusive): 65535
Configuration& sts_port(size_t val) {
if (val < 1ull || val > 65535ull) {
std::string err;
err += "sts_port must be between 1 and 65535, got ";
err += std::to_string(val);
throw std::runtime_error(err);
}
sts_port_ = val;
return *this;
}

/// Enables or disable the use of a thread pool for the SDK Client.
/// Default: false
/// \param val whether or not to use a thread pool
Expand Down Expand Up @@ -1078,6 +1135,9 @@ class Configuration : private boost::noncopyable {
proxy_port(c.proxy_port());
proxy_user_name(c.proxy_user_name());
proxy_password(c.proxy_password());
sts_endpoint(c.sts_endpoint());
sts_port(c.sts_port());

if (c.thread_config() == ::aws::kinesis::protobuf::Configuration_ThreadConfig::Configuration_ThreadConfig_POOLED) {
use_thread_pool(true);
thread_pool_size(c.thread_pool_size());
Expand Down Expand Up @@ -1123,6 +1183,8 @@ class Configuration : private boost::noncopyable {
size_t proxy_port_ = 443;
std::string proxy_user_name_ = "";
std::string proxy_password_ = "";
std::string sts_endpoint_ = "";
size_t sts_port_ = 443;

bool use_thread_pool_ = true;
uint32_t thread_pool_size_ = 64;
Expand Down
27 changes: 25 additions & 2 deletions aws/kinesis/core/kinesis_producer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,16 @@ namespace {
struct EndpointConfiguration {
std::string kinesis_endpoint_;
std::string cloudwatch_endpoint_;
std::string sts_endpoint_;

EndpointConfiguration(std::string kinesis_endpoint, std::string cloudwatch_endpoint) :
kinesis_endpoint_(kinesis_endpoint), cloudwatch_endpoint_(cloudwatch_endpoint) {}
EndpointConfiguration(std::string kinesis_endpoint, std::string cloudwatch_endpoint) {
EndpointConfiguration(kinesis_endpoint, cloudwatch_endpoint, {});
}

EndpointConfiguration(std::string kinesis_endpoint, std::string cloudwatch_endpoint, std::string sts_endpoint) :
kinesis_endpoint_(kinesis_endpoint),
cloudwatch_endpoint_(cloudwatch_endpoint),
sts_endpoint_(sts_endpoint) {}
};

const constexpr char* kVersion = "0.14.13N";
Expand Down Expand Up @@ -200,6 +207,21 @@ void KinesisProducer::create_cw_client(const std::string& ca_path) {
cfg);
}

void KinesisProducer::create_sts_client(const std::string& ca_path) {
auto cfg = make_sdk_client_cfg(*config_, region_, ca_path, 2);
if (config_->sts_endpoint().size() > 0) {
cfg.endpointOverride = config_->sts_endpoint() + ":" +
std::to_string(config_->sts_port());
LOG(info) << "Using STS endpoint " + cfg.endpointOverride;
} else {
set_override_if_present(region_, cfg, "STS", [](auto ep) -> std::string { return ep.sts_endpoint_; });
}

sts_client_ = std::make_shared<Aws::STS::STSClient>(
kinesis_creds_provider_, // STS doesn't require any permissions, so Kinesis cred works here
cfg);
}

Pipeline* KinesisProducer::create_pipeline(const std::string& stream) {
LOG(info) << "Created pipeline for stream \"" << stream << "\"";
return new Pipeline(
Expand All @@ -209,6 +231,7 @@ Pipeline* KinesisProducer::create_pipeline(const std::string& stream) {
executor_,
kinesis_client_,
metrics_manager_,
sts_client_,
[this](auto& ur) {
ipc_manager_->put(ur->to_put_record_result().SerializeAsString());
});
Expand Down
4 changes: 4 additions & 0 deletions aws/kinesis/core/kinesis_producer.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class KinesisProducer : boost::noncopyable {
shutdown_(false) {
create_kinesis_client(ca_path);
create_cw_client(ca_path);
create_sts_client(ca_path);
create_metrics_manager();
report_outstanding();
message_drainer_ = aws::thread([this] { this->drain_messages(); });
Expand All @@ -77,6 +78,8 @@ class KinesisProducer : boost::noncopyable {

void create_cw_client(const std::string& ca_path);

void create_sts_client(const std::string& ca_path);

Pipeline* create_pipeline(const std::string& stream);

void drain_messages();
Expand All @@ -103,6 +106,7 @@ class KinesisProducer : boost::noncopyable {
cw_creds_provider_;
std::shared_ptr<Aws::Kinesis::KinesisClient> kinesis_client_;
std::shared_ptr<Aws::CloudWatch::CloudWatchClient> cw_client_;
std::shared_ptr<Aws::STS::STSClient> sts_client_;
std::shared_ptr<aws::utils::Executor> executor_;

std::shared_ptr<IpcManager> ipc_manager_;
Expand Down
47 changes: 45 additions & 2 deletions aws/kinesis/core/pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <boost/format.hpp>
#include <iomanip>

#include <aws/core/utils/ARN.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/kinesis/core/aggregator.h>
#include <aws/kinesis/core/collector.h>
#include <aws/kinesis/core/configuration.h>
Expand All @@ -29,6 +31,12 @@
#include <aws/kinesis/KinesisClient.h>
#include <aws/metrics/metrics_manager.h>
#include <aws/utils/processing_statistics_logger.h>
#include <aws/sts/STSClient.h>
#include <aws/sts/model/GetCallerIdentityRequest.h>
#include <aws/sts/model/GetCallerIdentityResult.h>

#include <aws/utils/logging.h>


namespace aws {
namespace kinesis {
Expand All @@ -46,20 +54,24 @@ class Pipeline : boost::noncopyable {
std::shared_ptr<aws::utils::Executor> executor,
std::shared_ptr<Aws::Kinesis::KinesisClient> kinesis_client,
std::shared_ptr<aws::metrics::MetricsManager> metrics_manager,
std::shared_ptr<Aws::STS::STSClient> sts_client,
Retrier::UserRecordCallback finish_user_record_cb)
: stream_(std::move(stream)),
region_(std::move(region)),
stream_arn_(std::move(init_stream_arn(sts_client, region_, stream_))),
config_(std::move(config)),
stats_logger_(stream_, config_->record_max_buffered_time()),
executor_(std::move(executor)),
kinesis_client_(std::move(kinesis_client)),
metrics_manager_(std::move(metrics_manager)),
sts_client_(std::move(sts_client)),
finish_user_record_cb_(std::move(finish_user_record_cb)),
shard_map_(
std::make_shared<ShardMap>(
executor_,
kinesis_client_,
stream_,
stream_arn_,
metrics_manager_)),
aggregator_(
std::make_shared<Aggregator>(
Expand Down Expand Up @@ -151,7 +163,7 @@ class Pipeline : boost::noncopyable {
}

void send_put_records_request(const std::shared_ptr<PutRecordsRequest>& prr) {
auto prc = std::make_shared<PutRecordsContext>(stream_, prr->items());
auto prc = std::make_shared<PutRecordsContext>(stream_, stream_arn_, prr->items());
prc->set_start(std::chrono::steady_clock::now());
kinesis_client_->PutRecordsAsync(
prc->to_sdk_request(),
Expand Down Expand Up @@ -190,13 +202,44 @@ class Pipeline : boost::noncopyable {
});
}

std::string stream_;
// Retrieve the account ID and partition from the STS service.
static std::string init_stream_arn(const std::shared_ptr<Aws::STS::STSClient>& sts_client,
const std::string &region,
const std::string &stream_name) {
Aws::STS::Model::GetCallerIdentityRequest request;
auto outcome = sts_client->GetCallerIdentity(request);
if (outcome.IsSuccess()) {
auto result = outcome.GetResult();
Aws::Utils::ARN sts_arn(result.GetArn());

// Construct and return the Kinesis stream ARN.
std::stringstream arn;
arn << "arn:" << sts_arn.GetPartition() << ":kinesis:" << region << ":" << result.GetAccount()
<< ":stream/" << stream_name;
Comment on lines +217 to +218
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this stream arn format same for all regions? I thought we had some issues with arn format in some regions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good call out, we once had a problem with pod1, but I think the format is good for all public commercial regions

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, we should still test to make sure this works in all partitions.


auto arn_str = arn.str();
LOG(info) << "StreamARN \"" << arn_str << "\" has been successfully configured, "
<< "and will be used in requests including ListShards and PutRecords";
return arn_str;
}
auto e = outcome.GetError();
auto code = e.GetExceptionName();
auto msg = e.GetMessage();
LOG(error) << "Failed to get StreamARN using STS GetCallerIdentity | Code: " << code
<< " | Message: " << msg
<< " | Request was: " << request.SerializePayload();
exit(EXIT_FAILURE);
}

std::string region_;
std::string stream_;
std::string stream_arn_;
std::shared_ptr<Configuration> config_;
aws::utils::processing_statistics_logger stats_logger_;
std::shared_ptr<aws::utils::Executor> executor_;
std::shared_ptr<Aws::Kinesis::KinesisClient> kinesis_client_;
std::shared_ptr<aws::metrics::MetricsManager> metrics_manager_;
std::shared_ptr<Aws::STS::STSClient> sts_client_;
Retrier::UserRecordCallback finish_user_record_cb_;

std::shared_ptr<ShardMap> shard_map_;
Expand Down
10 changes: 9 additions & 1 deletion aws/kinesis/core/put_records_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,20 @@ namespace core {
class PutRecordsContext : public Aws::Client::AsyncCallerContext {
public:
PutRecordsContext(std::string stream,
std::string stream_arn,
std::vector<std::shared_ptr<KinesisRecord>> records)
: stream_(stream),
: stream_(std::move(stream)),
stream_arn_(std::move(stream_arn)),
records_(std::move(records)) {}

const std::string& get_stream() const {
return stream_;
}

const std::string& get_stream_arn() const {
return stream_arn_;
}

std::chrono::steady_clock::time_point get_start() const {
return start_;
}
Expand Down Expand Up @@ -76,6 +82,7 @@ class PutRecordsContext : public Aws::Client::AsyncCallerContext {
req.AddRecords(std::move(e));
}
req.SetStreamName(stream_);
if (!stream_arn_.empty()) req.SetStreamARN(stream_arn_);
return req;
}

Expand All @@ -96,6 +103,7 @@ class PutRecordsContext : public Aws::Client::AsyncCallerContext {

private:
std::string stream_;
std::string stream_arn_;
std::chrono::steady_clock::time_point start_;
std::chrono::steady_clock::time_point end_;
std::vector<std::shared_ptr<KinesisRecord>> records_;
Expand Down
1 change: 1 addition & 0 deletions aws/kinesis/core/retrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class MetricsPutter {
private:
std::shared_ptr<aws::metrics::MetricsManager> metrics_manager_;
std::string stream_;
std::string stream_arn_;
};

} // namespace detail
Expand Down
13 changes: 9 additions & 4 deletions aws/kinesis/core/shard_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ ShardMap::ShardMap(
std::shared_ptr<aws::utils::Executor> executor,
std::shared_ptr<Aws::Kinesis::KinesisClient> kinesis_client,
std::string stream,
std::string stream_arn,
std::shared_ptr<aws::metrics::MetricsManager> metrics_manager,
std::chrono::milliseconds min_backoff,
std::chrono::milliseconds max_backoff)
: executor_(std::move(executor)),
kinesis_client_(std::move(kinesis_client)),
stream_(std::move(stream)),
stream_arn_(std::move(stream_arn)),
metrics_manager_(std::move(metrics_manager)),
state_(INVALID),
min_backoff_(min_backoff),
Expand Down Expand Up @@ -101,8 +103,10 @@ void ShardMap::list_shards(const Aws::String& next_token) {

if (!next_token.empty()) {
req.SetNextToken(next_token);
if (!stream_arn_.empty()) req.SetStreamARN(stream_arn_);
} else {
req.SetStreamName(stream_);
if (!stream_arn_.empty()) req.SetStreamARN(stream_arn_);
Aws::Kinesis::Model::ShardFilter shardFilter;
shardFilter.SetType(Aws::Kinesis::Model::ShardFilterType::AT_LATEST);
req.SetShardFilter(shardFilter);
Expand Down Expand Up @@ -146,12 +150,13 @@ void ShardMap::list_shards_callback(
updated_at_ = std::chrono::steady_clock::now();

LOG(info) << "Successfully updated shard map for stream \""
<< stream_ << "\" found " << end_hash_key_to_shard_id_.size()
<< " shards";
<< stream_ << (stream_arn_.empty() ? "\"" : "\" (arn: \"" + stream_arn_ + "\"). Found ")
<< end_hash_key_to_shard_id_.size() << " shards";
}

void ShardMap::update_fail(const std::string& code, const std::string& msg) {
LOG(error) << "Shard map update for stream \"" << stream_ << "\" failed. "
void ShardMap::update_fail(const std::string &code, const std::string &msg) {
LOG(error) << "Shard map update for stream \""
<< stream_ << (stream_arn_.empty() ? "\"" : "\" (arn: \"" + stream_arn_ + "\") failed. ")
<< "Code: " << code << " Message: " << msg << "; retrying in "
<< backoff_.count() << " ms";

Expand Down
2 changes: 2 additions & 0 deletions aws/kinesis/core/shard_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ShardMap : boost::noncopyable {
ShardMap(std::shared_ptr<aws::utils::Executor> executor,
std::shared_ptr<Aws::Kinesis::KinesisClient> kinesis_client,
std::string stream,
std::string stream_arn,
std::shared_ptr<aws::metrics::MetricsManager> metrics_manager
= std::make_shared<aws::metrics::NullMetricsManager>(),
std::chrono::milliseconds min_backoff = kMinBackoff,
Expand Down Expand Up @@ -88,6 +89,7 @@ class ShardMap : boost::noncopyable {
std::shared_ptr<aws::utils::Executor> executor_;
std::shared_ptr<Aws::Kinesis::KinesisClient> kinesis_client_;
std::string stream_;
std::string stream_arn_;
std::shared_ptr<aws::metrics::MetricsManager> metrics_manager_;

State state_;
Expand Down
1 change: 1 addition & 0 deletions aws/kinesis/core/test/retrier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ auto make_prr_ctx(size_t num_kr,
}
auto ctx = std::make_shared<aws::kinesis::core::PutRecordsContext>(
"myStream",
"arn:aws:kinesis:us-east-2:123456789012:stream/myStream",
krs);
ctx->set_outcome(outcome);
return ctx;
Expand Down
Loading