1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <aws/core/Aws.h>
17 #include <aws/core/config/AWSProfileConfigLoader.h>
18 #include <aws/core/utils/Outcome.h>
19 #include <aws/kinesis/KinesisClient.h>
20 #include <aws/kinesis/model/DescribeStreamRequest.h>
21 #include <aws/kinesis/model/GetRecordsRequest.h>
22 #include <aws/kinesis/model/GetShardIteratorRequest.h>
23 #include <aws/kinesis/model/PutRecordsRequest.h>
24 #include <aws/kinesis/model/ShardIteratorType.h>
25 #include "tensorflow/core/framework/dataset.h"
26 #include "tensorflow/core/platform/s3/aws_crypto.h"
27
28 namespace tensorflow {
29 namespace {
30
InitializeDefaultClientConfig()31 Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
32 static Aws::Client::ClientConfiguration config;
33 const char* endpoint = getenv("KINESIS_ENDPOINT");
34 if (endpoint) {
35 config.endpointOverride = Aws::String(endpoint);
36 }
37 const char* region = getenv("AWS_REGION");
38 if (region) {
39 config.region = Aws::String(region);
40 } else {
41 // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
42 // is set with a truthy value.
43 const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
44 string load_config =
45 load_config_env ? str_util::Lowercase(load_config_env) : "";
46 if (load_config == "true" || load_config == "1") {
47 Aws::String config_file;
48 // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
49 const char* config_file_env = getenv("AWS_CONFIG_FILE");
50 if (config_file_env) {
51 config_file = config_file_env;
52 } else {
53 const char* home_env = getenv("HOME");
54 if (home_env) {
55 config_file = home_env;
56 config_file += "/.aws/config";
57 }
58 }
59 Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
60 // Load the configuration. If successful, get the region.
61 // If the load is not successful, then generate a warning.
62 if (loader.Load()) {
63 auto profiles = loader.GetProfiles();
64 if (!profiles["default"].GetRegion().empty()) {
65 config.region = profiles["default"].GetRegion();
66 }
67 } else {
68 LOG(WARNING) << "Failed to load the profile in " << config_file << ".";
69 }
70 }
71 }
72 const char* use_https = getenv("KINESIS_USE_HTTPS");
73 if (use_https) {
74 if (use_https[0] == '0') {
75 config.scheme = Aws::Http::Scheme::HTTP;
76 } else {
77 config.scheme = Aws::Http::Scheme::HTTPS;
78 }
79 }
80 const char* verify_ssl = getenv("KINESIS_VERIFY_SSL");
81 if (verify_ssl) {
82 if (verify_ssl[0] == '0') {
83 config.verifySSL = false;
84 } else {
85 config.verifySSL = true;
86 }
87 }
88 const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC");
89 if (connect_timeout) {
90 int64 timeout;
91
92 if (strings::safe_strto64(connect_timeout, &timeout)) {
93 config.connectTimeoutMs = timeout;
94 }
95 }
96 const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC");
97 if (request_timeout) {
98 int64 timeout;
99
100 if (strings::safe_strto64(request_timeout, &timeout)) {
101 config.requestTimeoutMs = timeout;
102 }
103 }
104
105 return &config;
106 }
107
GetDefaultClientConfig()108 Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
109 static Aws::Client::ClientConfiguration* config =
110 InitializeDefaultClientConfig();
111 return *config;
112 }
113
114 static mutex mu(LINKER_INITIALIZED);
115 static unsigned count(0);
AwsInitAPI()116 void AwsInitAPI() {
117 mutex_lock lock(mu);
118 count++;
119 if (count == 1) {
120 Aws::SDKOptions options;
121 options.cryptoOptions.sha256Factory_create_fn = []() {
122 return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
123 };
124 options.cryptoOptions.sha256HMACFactory_create_fn = []() {
125 return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
126 };
127 Aws::InitAPI(options);
128 }
129 }
AwsShutdownAPI()130 void AwsShutdownAPI() {
131 mutex_lock lock(mu);
132 count--;
133 if (count == 0) {
134 Aws::SDKOptions options;
135 Aws::ShutdownAPI(options);
136 }
137 }
ShutdownClient(Aws::Kinesis::KinesisClient * client)138 void ShutdownClient(Aws::Kinesis::KinesisClient* client) {
139 if (client != nullptr) {
140 delete client;
141 AwsShutdownAPI();
142 }
143 }
144 }
145 class KinesisDatasetOp : public DatasetOpKernel {
146 public:
147 using DatasetOpKernel::DatasetOpKernel;
148
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)149 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
150 std::string stream = "";
151 OP_REQUIRES_OK(ctx,
152 ParseScalarArgument<std::string>(ctx, "stream", &stream));
153 std::string shard = "";
154 OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "shard", &shard));
155 bool read_indefinitely = true;
156 OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "read_indefinitely",
157 &read_indefinitely));
158 int64 interval = -1;
159 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "interval", &interval));
160 OP_REQUIRES(ctx, (interval > 0),
161 errors::InvalidArgument(
162 "Interval value should be large than 0, got ", interval));
163 *output = new Dataset(ctx, stream, shard, read_indefinitely, interval);
164 }
165
166 private:
167 class Dataset : public DatasetBase {
168 public:
Dataset(OpKernelContext * ctx,const string & stream,const string & shard,const bool read_indefinitely,const int64 interval)169 Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
170 const bool read_indefinitely, const int64 interval)
171 : DatasetBase(DatasetContext(ctx)),
172 stream_(stream),
173 shard_(shard),
174 read_indefinitely_(read_indefinitely),
175 interval_(interval) {}
176
MakeIteratorInternal(const string & prefix) const177 std::unique_ptr<IteratorBase> MakeIteratorInternal(
178 const string& prefix) const override {
179 return std::unique_ptr<IteratorBase>(
180 new Iterator({this, strings::StrCat(prefix, "::Kinesis")}));
181 }
182
output_dtypes() const183 const DataTypeVector& output_dtypes() const override {
184 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
185 return *dtypes;
186 }
187
output_shapes() const188 const std::vector<PartialTensorShape>& output_shapes() const override {
189 static std::vector<PartialTensorShape>* shapes =
190 new std::vector<PartialTensorShape>({{}});
191 return *shapes;
192 }
193
DebugString() const194 string DebugString() const override { return "KinesisDatasetOp::Dataset"; }
195
196 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const197 Status AsGraphDefInternal(SerializationContext* ctx,
198 DatasetGraphDefBuilder* b,
199 Node** output) const override {
200 Node* stream = nullptr;
201 TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream));
202 Node* shard = nullptr;
203 TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard));
204 Node* read_indefinitely = nullptr;
205 TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely));
206 Node* interval = nullptr;
207 TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval));
208 TF_RETURN_IF_ERROR(b->AddDataset(
209 this, {stream, shard, read_indefinitely, interval}, output));
210 return Status::OK();
211 }
212
213 private:
214 class Iterator : public DatasetIterator<Dataset> {
215 public:
Iterator(const Params & params)216 explicit Iterator(const Params& params)
217 : DatasetIterator<Dataset>(params),
218 client_(nullptr, ShutdownClient) {}
219
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)220 Status GetNextInternal(IteratorContext* ctx,
221 std::vector<Tensor>* out_tensors,
222 bool* end_of_sequence) override {
223 mutex_lock l(mu_);
224 if (iterator_ == "") {
225 TF_RETURN_IF_ERROR(SetupStreamsLocked());
226 }
227 do {
228 Aws::Kinesis::Model::GetRecordsRequest request;
229 auto outcome = client_->GetRecords(
230 request.WithShardIterator(iterator_).WithLimit(1));
231 if (!outcome.IsSuccess()) {
232 return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
233 outcome.GetError().GetMessage());
234 }
235 if (outcome.GetResult().GetRecords().size() == 0) {
236 // If no records were returned then nothing is available at the
237 // moment.
238 if (!dataset()->read_indefinitely_) {
239 *end_of_sequence = true;
240 return Status::OK();
241 }
242 // Continue the loop after a period of time.
243 ctx->env()->SleepForMicroseconds(dataset()->interval_);
244 continue;
245 }
246 if (outcome.GetResult().GetRecords().size() != 1) {
247 return errors::Unknown("invalid number of records ",
248 outcome.GetResult().GetRecords().size(),
249 " returned");
250 }
251
252 iterator_ = outcome.GetResult().GetNextShardIterator();
253
254 const auto& data = outcome.GetResult().GetRecords()[0].GetData();
255 StringPiece value(
256 reinterpret_cast<const char*>(data.GetUnderlyingData()),
257 data.GetLength());
258 Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
259 value_tensor.scalar<std::string>()() = std::string(value);
260 out_tensors->emplace_back(std::move(value_tensor));
261
262 *end_of_sequence = false;
263 return Status::OK();
264 } while (true);
265 }
266
267 protected:
SaveInternal(IteratorStateWriter * writer)268 Status SaveInternal(IteratorStateWriter* writer) override {
269 return errors::Unimplemented("SaveInternal is currently not supported");
270 }
271
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)272 Status RestoreInternal(IteratorContext* ctx,
273 IteratorStateReader* reader) override {
274 return errors::Unimplemented(
275 "RestoreInternal is currently not supported");
276 }
277
278 private:
279 // Sets up Kinesis streams to read from.
SetupStreamsLocked()280 Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
281 AwsInitAPI();
282 client_.reset(
283 new Aws::Kinesis::KinesisClient(GetDefaultClientConfig()));
284
285 Aws::Kinesis::Model::DescribeStreamRequest request;
286 auto outcome = client_->DescribeStream(
287 request.WithStreamName(dataset()->stream_.c_str()));
288 if (!outcome.IsSuccess()) {
289 return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
290 outcome.GetError().GetMessage());
291 }
292 Aws::String shard;
293 Aws::String sequence;
294 if (dataset()->shard_ == "") {
295 if (outcome.GetResult().GetStreamDescription().GetShards().size() !=
296 1) {
297 return errors::InvalidArgument(
298 "shard has to be provided unless the stream only have one "
299 "shard, there are ",
300 outcome.GetResult().GetStreamDescription().GetShards().size(),
301 " shards in stream ", dataset()->stream_);
302 }
303 shard = outcome.GetResult()
304 .GetStreamDescription()
305 .GetShards()[0]
306 .GetShardId();
307 sequence = outcome.GetResult()
308 .GetStreamDescription()
309 .GetShards()[0]
310 .GetSequenceNumberRange()
311 .GetStartingSequenceNumber();
312 } else {
313 for (const auto& entry :
314 outcome.GetResult().GetStreamDescription().GetShards()) {
315 if (entry.GetShardId() == dataset()->shard_.c_str()) {
316 shard = entry.GetShardId();
317 sequence =
318 entry.GetSequenceNumberRange().GetStartingSequenceNumber();
319 break;
320 }
321 }
322 if (shard == "") {
323 return errors::InvalidArgument("no shard ", dataset()->shard_,
324 " in stream ", dataset()->stream_);
325 }
326 }
327
328 Aws::Kinesis::Model::GetShardIteratorRequest iterator_request;
329 auto iterator_outcome = client_->GetShardIterator(
330 iterator_request.WithStreamName(dataset()->stream_.c_str())
331 .WithShardId(shard)
332 .WithShardIteratorType(
333 Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER)
334 .WithStartingSequenceNumber(sequence));
335 if (!iterator_outcome.IsSuccess()) {
336 return errors::Unknown(iterator_outcome.GetError().GetExceptionName(),
337 ": ",
338 iterator_outcome.GetError().GetMessage());
339 }
340 iterator_ = iterator_outcome.GetResult().GetShardIterator();
341 return Status::OK();
342 }
343
344 mutex mu_;
345 Aws::String iterator_ GUARDED_BY(mu_);
346 std::unique_ptr<Aws::Kinesis::KinesisClient, decltype(&ShutdownClient)>
347 client_ GUARDED_BY(mu_);
348 };
349
350 const std::string stream_;
351 const std::string shard_;
352 const bool read_indefinitely_;
353 const int64 interval_;
354 };
355 };
356
357 REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU),
358 KinesisDatasetOp);
359
360 } // namespace tensorflow
361