1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <utility> 16 17 #include "tensorflow/core/framework/dataset.h" 18 #include "tensorflow/core/framework/partial_tensor_shape.h" 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/kernels/data/experimental/sql/driver_manager.h" 21 #include "tensorflow/core/kernels/data/experimental/sql/query_connection.h" 22 #include "tensorflow/core/lib/io/inputbuffer.h" 23 #include "tensorflow/core/lib/io/record_reader.h" 24 #include "tensorflow/core/lib/strings/stringprintf.h" 25 26 namespace tensorflow { 27 namespace data { 28 namespace experimental { 29 namespace { 30 31 class SqlDatasetOp : public DatasetOpKernel { 32 public: SqlDatasetOp(OpKernelConstruction * ctx)33 explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { 34 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 35 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 36 for (const DataType& dt : output_types_) { 37 OP_REQUIRES(ctx, 38 dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 || 39 dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 || 40 dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE, 41 errors::InvalidArgument( 42 "Each element of `output_types_` must be one of: " 43 "DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, " 44 "DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE ")); 45 } 46 for (const PartialTensorShape& pts : output_shapes_) { 47 OP_REQUIRES(ctx, pts.dims() == 0, 48 errors::InvalidArgument( 49 "Each element of `output_shapes_` must be a scalar.")); 50 } 51 } MakeDataset(OpKernelContext * ctx,DatasetBase ** output)52 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 53 tstring driver_name; 54 OP_REQUIRES_OK( 55 ctx, ParseScalarArgument<tstring>(ctx, "driver_name", &driver_name)); 56 57 tstring data_source_name; 58 OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "data_source_name", 59 &data_source_name)); 60 61 tstring query; 62 OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "query", &query)); 63 64 // TODO(b/64276826) Change this check when we add support for other 65 // databases. 66 OP_REQUIRES(ctx, driver_name == "sqlite", 67 errors::InvalidArgument(tensorflow::strings::Printf( 68 "The database type, %s, is not supported by SqlDataset. " 69 "The set of supported databases is: {'sqlite'}.", 70 driver_name.c_str()))); 71 72 *output = new Dataset(ctx, driver_name, data_source_name, query, 73 output_types_, output_shapes_); 74 } 75 76 private: 77 class Dataset : public DatasetBase { 78 public: Dataset(OpKernelContext * ctx,const string & driver_name,const string & data_source_name,const string & query,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)79 Dataset(OpKernelContext* ctx, const string& driver_name, 80 const string& data_source_name, const string& query, 81 const DataTypeVector& output_types, 82 const std::vector<PartialTensorShape>& output_shapes) 83 : DatasetBase(DatasetContext(ctx)), 84 driver_name_(driver_name), 85 data_source_name_(data_source_name), 86 query_(query), 87 output_types_(output_types), 88 output_shapes_(output_shapes) {} 89 MakeIteratorInternal(const string & prefix) const90 std::unique_ptr<IteratorBase> MakeIteratorInternal( 91 const string& prefix) const override { 92 return std::make_unique<Iterator>( 93 Iterator::Params{this, strings::StrCat(prefix, "::Sql")}); 94 } 95 output_dtypes() const96 const DataTypeVector& output_dtypes() const override { 97 return output_types_; 98 } 99 output_shapes() const100 const std::vector<PartialTensorShape>& output_shapes() const override { 101 return output_shapes_; 102 } 103 DebugString() const104 string DebugString() const override { return "SqlDatasetOp::Dataset"; } 105 InputDatasets(std::vector<const DatasetBase * > * inputs) const106 Status InputDatasets( 107 std::vector<const DatasetBase*>* inputs) const override { 108 return OkStatus(); 109 } 110 CheckExternalState() const111 Status CheckExternalState() const override { return OkStatus(); } 112 113 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const114 Status AsGraphDefInternal(SerializationContext* ctx, 115 DatasetGraphDefBuilder* b, 116 Node** output) const override { 117 Node* driver_name_node; 118 TF_RETURN_IF_ERROR(b->AddScalar(driver_name_, &driver_name_node)); 119 Node* data_source_name_node; 120 TF_RETURN_IF_ERROR( 121 b->AddScalar(data_source_name_, &data_source_name_node)); 122 Node* query_node; 123 TF_RETURN_IF_ERROR(b->AddScalar(query_, &query_node)); 124 TF_RETURN_IF_ERROR(b->AddDataset( 125 this, {driver_name_node, data_source_name_node, query_node}, output)); 126 return OkStatus(); 127 } 128 129 private: 130 class Iterator : public DatasetIterator<Dataset> { 131 public: Iterator(const Params & params)132 explicit Iterator(const Params& params) 133 : DatasetIterator<Dataset>(params) {} ~Iterator()134 ~Iterator() override { 135 if (query_connection_initialized_) { 136 Status s = query_connection_->Close(); 137 if (!s.ok()) { 138 LOG(WARNING) << "Failed to close query connection: " << s; 139 } 140 } 141 } 142 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)143 Status GetNextInternal(IteratorContext* ctx, 144 std::vector<Tensor>* out_tensors, 145 bool* end_of_sequence) override { 146 mutex_lock l(mu_); 147 if (!query_connection_initialized_) { 148 TF_RETURN_IF_ERROR(InitializeQueryConnection()); 149 } 150 Status status = OkStatus(); 151 if (!end_of_sequence_) { 152 next_calls_++; 153 status = 154 query_connection_->GetNext(ctx, out_tensors, &end_of_sequence_); 155 } 156 *end_of_sequence = end_of_sequence_; 157 return status; 158 } 159 160 protected: CreateNode(IteratorContext * ctx,model::Node::Args args) const161 std::shared_ptr<model::Node> CreateNode( 162 IteratorContext* ctx, model::Node::Args args) const override { 163 return model::MakeSourceNode(std::move(args)); 164 } 165 SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)166 Status SaveInternal(SerializationContext* ctx, 167 IteratorStateWriter* writer) override { 168 mutex_lock l(mu_); 169 if (query_connection_initialized_) { 170 TF_RETURN_IF_ERROR( 171 writer->WriteScalar(full_name("next_calls"), next_calls_)); 172 } 173 return OkStatus(); 174 } 175 RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)176 Status RestoreInternal(IteratorContext* ctx, 177 IteratorStateReader* reader) override { 178 mutex_lock l(mu_); 179 if (reader->Contains(full_name("next_calls"))) { 180 TF_RETURN_IF_ERROR(InitializeQueryConnection()); 181 TF_RETURN_IF_ERROR( 182 reader->ReadScalar(full_name("next_calls"), &next_calls_)); 183 int64_t rem_next_calls = next_calls_; 184 std::vector<Tensor> out_tensors; 185 end_of_sequence_ = false; 186 while (rem_next_calls--) { 187 TF_RETURN_IF_ERROR(query_connection_->GetNext(ctx, &out_tensors, 188 &end_of_sequence_)); 189 out_tensors.clear(); 190 } 191 } else { 192 query_connection_initialized_ = false; 193 end_of_sequence_ = false; 194 } 195 return OkStatus(); 196 } 197 198 private: InitializeQueryConnection()199 Status InitializeQueryConnection() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 200 query_connection_initialized_ = true; 201 end_of_sequence_ = false; 202 query_connection_ = 203 sql::DriverManager::CreateQueryConnection(dataset()->driver_name_); 204 Status s = query_connection_->Open(dataset()->data_source_name_, 205 dataset()->query_, 206 dataset()->output_types_); 207 next_calls_ = 0; 208 if (!s.ok()) { 209 LOG(WARNING) << "Failed to connect to database: " << s; 210 return s; 211 } 212 return OkStatus(); 213 } 214 215 mutex mu_; 216 // TODO(b/129062371): explore ways to seek into a SQLite databases. 217 int64_t next_calls_ TF_GUARDED_BY(mu_) = 0; 218 std::unique_ptr<sql::QueryConnection> query_connection_ 219 TF_GUARDED_BY(mu_); 220 bool query_connection_initialized_ TF_GUARDED_BY(mu_) = false; 221 bool end_of_sequence_ TF_GUARDED_BY(mu_) = false; 222 }; 223 const tstring driver_name_; 224 const tstring data_source_name_; 225 const tstring query_; 226 const DataTypeVector output_types_; 227 const std::vector<PartialTensorShape> output_shapes_; 228 }; 229 DataTypeVector output_types_; 230 std::vector<PartialTensorShape> output_shapes_; 231 }; 232 233 REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp); 234 REGISTER_KERNEL_BUILDER(Name("ExperimentalSqlDataset").Device(DEVICE_CPU), 235 SqlDatasetOp); 236 237 } // namespace 238 } // namespace experimental 239 } // namespace data 240 } // namespace tensorflow 241