• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <aws/core/Aws.h>
17 #include <aws/core/config/AWSProfileConfigLoader.h>
18 #include <aws/core/utils/Outcome.h>
19 #include <aws/kinesis/KinesisClient.h>
20 #include <aws/kinesis/model/DescribeStreamRequest.h>
21 #include <aws/kinesis/model/GetRecordsRequest.h>
22 #include <aws/kinesis/model/GetShardIteratorRequest.h>
23 #include <aws/kinesis/model/PutRecordsRequest.h>
24 #include <aws/kinesis/model/ShardIteratorType.h>
25 #include "tensorflow/core/framework/dataset.h"
26 #include "tensorflow/core/platform/s3/aws_crypto.h"
27 
28 namespace tensorflow {
29 namespace {
30 
InitializeDefaultClientConfig()31 Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
32   static Aws::Client::ClientConfiguration config;
33   const char* endpoint = getenv("KINESIS_ENDPOINT");
34   if (endpoint) {
35     config.endpointOverride = Aws::String(endpoint);
36   }
37   const char* region = getenv("AWS_REGION");
38   if (region) {
39     config.region = Aws::String(region);
40   } else {
41     // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
42     // is set with a truthy value.
43     const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
44     string load_config =
45         load_config_env ? str_util::Lowercase(load_config_env) : "";
46     if (load_config == "true" || load_config == "1") {
47       Aws::String config_file;
48       // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
49       const char* config_file_env = getenv("AWS_CONFIG_FILE");
50       if (config_file_env) {
51         config_file = config_file_env;
52       } else {
53         const char* home_env = getenv("HOME");
54         if (home_env) {
55           config_file = home_env;
56           config_file += "/.aws/config";
57         }
58       }
59       Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
60       // Load the configuration. If successful, get the region.
61       // If the load is not successful, then generate a warning.
62       if (loader.Load()) {
63         auto profiles = loader.GetProfiles();
64         if (!profiles["default"].GetRegion().empty()) {
65           config.region = profiles["default"].GetRegion();
66         }
67       } else {
68         LOG(WARNING) << "Failed to load the profile in " << config_file << ".";
69       }
70     }
71   }
72   const char* use_https = getenv("KINESIS_USE_HTTPS");
73   if (use_https) {
74     if (use_https[0] == '0') {
75       config.scheme = Aws::Http::Scheme::HTTP;
76     } else {
77       config.scheme = Aws::Http::Scheme::HTTPS;
78     }
79   }
80   const char* verify_ssl = getenv("KINESIS_VERIFY_SSL");
81   if (verify_ssl) {
82     if (verify_ssl[0] == '0') {
83       config.verifySSL = false;
84     } else {
85       config.verifySSL = true;
86     }
87   }
88   const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC");
89   if (connect_timeout) {
90     int64 timeout;
91 
92     if (strings::safe_strto64(connect_timeout, &timeout)) {
93       config.connectTimeoutMs = timeout;
94     }
95   }
96   const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC");
97   if (request_timeout) {
98     int64 timeout;
99 
100     if (strings::safe_strto64(request_timeout, &timeout)) {
101       config.requestTimeoutMs = timeout;
102     }
103   }
104 
105   return &config;
106 }
107 
GetDefaultClientConfig()108 Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
109   static Aws::Client::ClientConfiguration* config =
110       InitializeDefaultClientConfig();
111   return *config;
112 }
113 
114 static mutex mu(LINKER_INITIALIZED);
115 static unsigned count(0);
AwsInitAPI()116 void AwsInitAPI() {
117   mutex_lock lock(mu);
118   count++;
119   if (count == 1) {
120     Aws::SDKOptions options;
121     options.cryptoOptions.sha256Factory_create_fn = []() {
122       return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
123     };
124     options.cryptoOptions.sha256HMACFactory_create_fn = []() {
125       return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
126     };
127     Aws::InitAPI(options);
128   }
129 }
AwsShutdownAPI()130 void AwsShutdownAPI() {
131   mutex_lock lock(mu);
132   count--;
133   if (count == 0) {
134     Aws::SDKOptions options;
135     Aws::ShutdownAPI(options);
136   }
137 }
ShutdownClient(Aws::Kinesis::KinesisClient * client)138 void ShutdownClient(Aws::Kinesis::KinesisClient* client) {
139   if (client != nullptr) {
140     delete client;
141     AwsShutdownAPI();
142   }
143 }
144 }
145 class KinesisDatasetOp : public DatasetOpKernel {
146  public:
147   using DatasetOpKernel::DatasetOpKernel;
148 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)149   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
150     std::string stream = "";
151     OP_REQUIRES_OK(ctx,
152                    ParseScalarArgument<std::string>(ctx, "stream", &stream));
153     std::string shard = "";
154     OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "shard", &shard));
155     bool read_indefinitely = true;
156     OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "read_indefinitely",
157                                                   &read_indefinitely));
158     int64 interval = -1;
159     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "interval", &interval));
160     OP_REQUIRES(ctx, (interval > 0),
161                 errors::InvalidArgument(
162                     "Interval value should be large than 0, got ", interval));
163     *output = new Dataset(ctx, stream, shard, read_indefinitely, interval);
164   }
165 
166  private:
167   class Dataset : public DatasetBase {
168    public:
Dataset(OpKernelContext * ctx,const string & stream,const string & shard,const bool read_indefinitely,const int64 interval)169     Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
170             const bool read_indefinitely, const int64 interval)
171         : DatasetBase(DatasetContext(ctx)),
172           stream_(stream),
173           shard_(shard),
174           read_indefinitely_(read_indefinitely),
175           interval_(interval) {}
176 
MakeIteratorInternal(const string & prefix) const177     std::unique_ptr<IteratorBase> MakeIteratorInternal(
178         const string& prefix) const override {
179       return std::unique_ptr<IteratorBase>(
180           new Iterator({this, strings::StrCat(prefix, "::Kinesis")}));
181     }
182 
output_dtypes() const183     const DataTypeVector& output_dtypes() const override {
184       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
185       return *dtypes;
186     }
187 
output_shapes() const188     const std::vector<PartialTensorShape>& output_shapes() const override {
189       static std::vector<PartialTensorShape>* shapes =
190           new std::vector<PartialTensorShape>({{}});
191       return *shapes;
192     }
193 
DebugString() const194     string DebugString() const override { return "KinesisDatasetOp::Dataset"; }
195 
196    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const197     Status AsGraphDefInternal(SerializationContext* ctx,
198                               DatasetGraphDefBuilder* b,
199                               Node** output) const override {
200       Node* stream = nullptr;
201       TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream));
202       Node* shard = nullptr;
203       TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard));
204       Node* read_indefinitely = nullptr;
205       TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely));
206       Node* interval = nullptr;
207       TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval));
208       TF_RETURN_IF_ERROR(b->AddDataset(
209           this, {stream, shard, read_indefinitely, interval}, output));
210       return Status::OK();
211     }
212 
213    private:
214     class Iterator : public DatasetIterator<Dataset> {
215      public:
Iterator(const Params & params)216       explicit Iterator(const Params& params)
217           : DatasetIterator<Dataset>(params),
218             client_(nullptr, ShutdownClient) {}
219 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)220       Status GetNextInternal(IteratorContext* ctx,
221                              std::vector<Tensor>* out_tensors,
222                              bool* end_of_sequence) override {
223         mutex_lock l(mu_);
224         if (iterator_ == "") {
225           TF_RETURN_IF_ERROR(SetupStreamsLocked());
226         }
227         do {
228           Aws::Kinesis::Model::GetRecordsRequest request;
229           auto outcome = client_->GetRecords(
230               request.WithShardIterator(iterator_).WithLimit(1));
231           if (!outcome.IsSuccess()) {
232             return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
233                                    outcome.GetError().GetMessage());
234           }
235           if (outcome.GetResult().GetRecords().size() == 0) {
236             // If no records were returned then nothing is available at the
237             // moment.
238             if (!dataset()->read_indefinitely_) {
239               *end_of_sequence = true;
240               return Status::OK();
241             }
242             // Continue the loop after a period of time.
243             ctx->env()->SleepForMicroseconds(dataset()->interval_);
244             continue;
245           }
246           if (outcome.GetResult().GetRecords().size() != 1) {
247             return errors::Unknown("invalid number of records ",
248                                    outcome.GetResult().GetRecords().size(),
249                                    " returned");
250           }
251 
252           iterator_ = outcome.GetResult().GetNextShardIterator();
253 
254           const auto& data = outcome.GetResult().GetRecords()[0].GetData();
255           StringPiece value(
256               reinterpret_cast<const char*>(data.GetUnderlyingData()),
257               data.GetLength());
258           Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
259           value_tensor.scalar<std::string>()() = std::string(value);
260           out_tensors->emplace_back(std::move(value_tensor));
261 
262           *end_of_sequence = false;
263           return Status::OK();
264         } while (true);
265       }
266 
267      protected:
SaveInternal(IteratorStateWriter * writer)268       Status SaveInternal(IteratorStateWriter* writer) override {
269         return errors::Unimplemented("SaveInternal is currently not supported");
270       }
271 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)272       Status RestoreInternal(IteratorContext* ctx,
273                              IteratorStateReader* reader) override {
274         return errors::Unimplemented(
275             "RestoreInternal is currently not supported");
276       }
277 
278      private:
279       // Sets up Kinesis streams to read from.
SetupStreamsLocked()280       Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
281         AwsInitAPI();
282         client_.reset(
283             new Aws::Kinesis::KinesisClient(GetDefaultClientConfig()));
284 
285         Aws::Kinesis::Model::DescribeStreamRequest request;
286         auto outcome = client_->DescribeStream(
287             request.WithStreamName(dataset()->stream_.c_str()));
288         if (!outcome.IsSuccess()) {
289           return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
290                                  outcome.GetError().GetMessage());
291         }
292         Aws::String shard;
293         Aws::String sequence;
294         if (dataset()->shard_ == "") {
295           if (outcome.GetResult().GetStreamDescription().GetShards().size() !=
296               1) {
297             return errors::InvalidArgument(
298                 "shard has to be provided unless the stream only have one "
299                 "shard, there are ",
300                 outcome.GetResult().GetStreamDescription().GetShards().size(),
301                 " shards in stream ", dataset()->stream_);
302           }
303           shard = outcome.GetResult()
304                       .GetStreamDescription()
305                       .GetShards()[0]
306                       .GetShardId();
307           sequence = outcome.GetResult()
308                          .GetStreamDescription()
309                          .GetShards()[0]
310                          .GetSequenceNumberRange()
311                          .GetStartingSequenceNumber();
312         } else {
313           for (const auto& entry :
314                outcome.GetResult().GetStreamDescription().GetShards()) {
315             if (entry.GetShardId() == dataset()->shard_.c_str()) {
316               shard = entry.GetShardId();
317               sequence =
318                   entry.GetSequenceNumberRange().GetStartingSequenceNumber();
319               break;
320             }
321           }
322           if (shard == "") {
323             return errors::InvalidArgument("no shard ", dataset()->shard_,
324                                            " in stream ", dataset()->stream_);
325           }
326         }
327 
328         Aws::Kinesis::Model::GetShardIteratorRequest iterator_request;
329         auto iterator_outcome = client_->GetShardIterator(
330             iterator_request.WithStreamName(dataset()->stream_.c_str())
331                 .WithShardId(shard)
332                 .WithShardIteratorType(
333                     Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER)
334                 .WithStartingSequenceNumber(sequence));
335         if (!iterator_outcome.IsSuccess()) {
336           return errors::Unknown(iterator_outcome.GetError().GetExceptionName(),
337                                  ": ",
338                                  iterator_outcome.GetError().GetMessage());
339         }
340         iterator_ = iterator_outcome.GetResult().GetShardIterator();
341         return Status::OK();
342       }
343 
344       mutex mu_;
345       Aws::String iterator_ GUARDED_BY(mu_);
346       std::unique_ptr<Aws::Kinesis::KinesisClient, decltype(&ShutdownClient)>
347           client_ GUARDED_BY(mu_);
348     };
349 
350     const std::string stream_;
351     const std::string shard_;
352     const bool read_indefinitely_;
353     const int64 interval_;
354   };
355 };
356 
357 REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU),
358                         KinesisDatasetOp);
359 
360 }  // namespace tensorflow
361