• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <utility>
16 
17 #include "tensorflow/core/framework/dataset.h"
18 #include "tensorflow/core/framework/partial_tensor_shape.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/kernels/data/experimental/sql/driver_manager.h"
21 #include "tensorflow/core/kernels/data/experimental/sql/query_connection.h"
22 #include "tensorflow/core/lib/io/inputbuffer.h"
23 #include "tensorflow/core/lib/io/record_reader.h"
24 #include "tensorflow/core/lib/strings/stringprintf.h"
25 
26 namespace tensorflow {
27 namespace data {
28 namespace experimental {
29 namespace {
30 
31 class SqlDatasetOp : public DatasetOpKernel {
32  public:
SqlDatasetOp(OpKernelConstruction * ctx)33   explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
34     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
35     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
36     for (const DataType& dt : output_types_) {
37       OP_REQUIRES(ctx,
38                   dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
39                       dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
40                       dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
41                   errors::InvalidArgument(
42                       "Each element of `output_types_` must be one of: "
43                       "DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
44                       "DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
45     }
46     for (const PartialTensorShape& pts : output_shapes_) {
47       OP_REQUIRES(ctx, pts.dims() == 0,
48                   errors::InvalidArgument(
49                       "Each element of `output_shapes_` must be a scalar."));
50     }
51   }
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)52   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
53     tstring driver_name;
54     OP_REQUIRES_OK(
55         ctx, ParseScalarArgument<tstring>(ctx, "driver_name", &driver_name));
56 
57     tstring data_source_name;
58     OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "data_source_name",
59                                                      &data_source_name));
60 
61     tstring query;
62     OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "query", &query));
63 
64     // TODO(b/64276826) Change this check when we add support for other
65     // databases.
66     OP_REQUIRES(ctx, driver_name == "sqlite",
67                 errors::InvalidArgument(tensorflow::strings::Printf(
68                     "The database type, %s, is not supported by SqlDataset. "
69                     "The set of supported databases is: {'sqlite'}.",
70                     driver_name.c_str())));
71 
72     *output = new Dataset(ctx, driver_name, data_source_name, query,
73                           output_types_, output_shapes_);
74   }
75 
76  private:
77   class Dataset : public DatasetBase {
78    public:
Dataset(OpKernelContext * ctx,const string & driver_name,const string & data_source_name,const string & query,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)79     Dataset(OpKernelContext* ctx, const string& driver_name,
80             const string& data_source_name, const string& query,
81             const DataTypeVector& output_types,
82             const std::vector<PartialTensorShape>& output_shapes)
83         : DatasetBase(DatasetContext(ctx)),
84           driver_name_(driver_name),
85           data_source_name_(data_source_name),
86           query_(query),
87           output_types_(output_types),
88           output_shapes_(output_shapes) {}
89 
MakeIteratorInternal(const string & prefix) const90     std::unique_ptr<IteratorBase> MakeIteratorInternal(
91         const string& prefix) const override {
92       return std::make_unique<Iterator>(
93           Iterator::Params{this, strings::StrCat(prefix, "::Sql")});
94     }
95 
output_dtypes() const96     const DataTypeVector& output_dtypes() const override {
97       return output_types_;
98     }
99 
output_shapes() const100     const std::vector<PartialTensorShape>& output_shapes() const override {
101       return output_shapes_;
102     }
103 
DebugString() const104     string DebugString() const override { return "SqlDatasetOp::Dataset"; }
105 
InputDatasets(std::vector<const DatasetBase * > * inputs) const106     Status InputDatasets(
107         std::vector<const DatasetBase*>* inputs) const override {
108       return OkStatus();
109     }
110 
CheckExternalState() const111     Status CheckExternalState() const override { return OkStatus(); }
112 
113    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const114     Status AsGraphDefInternal(SerializationContext* ctx,
115                               DatasetGraphDefBuilder* b,
116                               Node** output) const override {
117       Node* driver_name_node;
118       TF_RETURN_IF_ERROR(b->AddScalar(driver_name_, &driver_name_node));
119       Node* data_source_name_node;
120       TF_RETURN_IF_ERROR(
121           b->AddScalar(data_source_name_, &data_source_name_node));
122       Node* query_node;
123       TF_RETURN_IF_ERROR(b->AddScalar(query_, &query_node));
124       TF_RETURN_IF_ERROR(b->AddDataset(
125           this, {driver_name_node, data_source_name_node, query_node}, output));
126       return OkStatus();
127     }
128 
129    private:
130     class Iterator : public DatasetIterator<Dataset> {
131      public:
Iterator(const Params & params)132       explicit Iterator(const Params& params)
133           : DatasetIterator<Dataset>(params) {}
~Iterator()134       ~Iterator() override {
135         if (query_connection_initialized_) {
136           Status s = query_connection_->Close();
137           if (!s.ok()) {
138             LOG(WARNING) << "Failed to close query connection: " << s;
139           }
140         }
141       }
142 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)143       Status GetNextInternal(IteratorContext* ctx,
144                              std::vector<Tensor>* out_tensors,
145                              bool* end_of_sequence) override {
146         mutex_lock l(mu_);
147         if (!query_connection_initialized_) {
148           TF_RETURN_IF_ERROR(InitializeQueryConnection());
149         }
150         Status status = OkStatus();
151         if (!end_of_sequence_) {
152           next_calls_++;
153           status =
154               query_connection_->GetNext(ctx, out_tensors, &end_of_sequence_);
155         }
156         *end_of_sequence = end_of_sequence_;
157         return status;
158       }
159 
160      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const161       std::shared_ptr<model::Node> CreateNode(
162           IteratorContext* ctx, model::Node::Args args) const override {
163         return model::MakeSourceNode(std::move(args));
164       }
165 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)166       Status SaveInternal(SerializationContext* ctx,
167                           IteratorStateWriter* writer) override {
168         mutex_lock l(mu_);
169         if (query_connection_initialized_) {
170           TF_RETURN_IF_ERROR(
171               writer->WriteScalar(full_name("next_calls"), next_calls_));
172         }
173         return OkStatus();
174       }
175 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)176       Status RestoreInternal(IteratorContext* ctx,
177                              IteratorStateReader* reader) override {
178         mutex_lock l(mu_);
179         if (reader->Contains(full_name("next_calls"))) {
180           TF_RETURN_IF_ERROR(InitializeQueryConnection());
181           TF_RETURN_IF_ERROR(
182               reader->ReadScalar(full_name("next_calls"), &next_calls_));
183           int64_t rem_next_calls = next_calls_;
184           std::vector<Tensor> out_tensors;
185           end_of_sequence_ = false;
186           while (rem_next_calls--) {
187             TF_RETURN_IF_ERROR(query_connection_->GetNext(ctx, &out_tensors,
188                                                           &end_of_sequence_));
189             out_tensors.clear();
190           }
191         } else {
192           query_connection_initialized_ = false;
193           end_of_sequence_ = false;
194         }
195         return OkStatus();
196       }
197 
198      private:
InitializeQueryConnection()199       Status InitializeQueryConnection() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
200         query_connection_initialized_ = true;
201         end_of_sequence_ = false;
202         query_connection_ =
203             sql::DriverManager::CreateQueryConnection(dataset()->driver_name_);
204         Status s = query_connection_->Open(dataset()->data_source_name_,
205                                            dataset()->query_,
206                                            dataset()->output_types_);
207         next_calls_ = 0;
208         if (!s.ok()) {
209           LOG(WARNING) << "Failed to connect to database: " << s;
210           return s;
211         }
212         return OkStatus();
213       }
214 
215       mutex mu_;
216       // TODO(b/129062371): explore ways to seek into a SQLite databases.
217       int64_t next_calls_ TF_GUARDED_BY(mu_) = 0;
218       std::unique_ptr<sql::QueryConnection> query_connection_
219           TF_GUARDED_BY(mu_);
220       bool query_connection_initialized_ TF_GUARDED_BY(mu_) = false;
221       bool end_of_sequence_ TF_GUARDED_BY(mu_) = false;
222     };
223     const tstring driver_name_;
224     const tstring data_source_name_;
225     const tstring query_;
226     const DataTypeVector output_types_;
227     const std::vector<PartialTensorShape> output_shapes_;
228   };
229   DataTypeVector output_types_;
230   std::vector<PartialTensorShape> output_shapes_;
231 };
232 
233 REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
234 REGISTER_KERNEL_BUILDER(Name("ExperimentalSqlDataset").Device(DEVICE_CPU),
235                         SqlDatasetOp);
236 
237 }  // namespace
238 }  // namespace experimental
239 }  // namespace data
240 }  // namespace tensorflow
241