1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <map>
17 #include <memory>
18 #include <set>
19
20 #include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"
21 #include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h"
22 #include "tensorflow/core/framework/reader_base.h"
23 #include "tensorflow/core/framework/reader_op_kernel.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/math/math_util.h"
26 #include "tensorflow/core/lib/strings/numbers.h"
27
28 namespace tensorflow {
29 namespace {
30
31 constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer.
32
33 // This is a helper function for reading table attributes from context.
// Reads the table-identifying attributes off the kernel's NodeDef into the
// given output parameters:
//   project_id / dataset_id / table_id - coordinates of the BigQuery table.
//   timestamp_millis - snapshot timestamp for the table read.
//   columns          - names of the columns to read.
//   test_end_point   - alternate API endpoint (used by tests).
// Returns the first GetAttr error encountered (the call order below fixes
// which missing attribute is reported), or OK if all attributes are present.
Status GetTableAttrs(OpKernelConstruction* context, string* project_id,
                     string* dataset_id, string* table_id,
                     int64* timestamp_millis, std::vector<string>* columns,
                     string* test_end_point) {
  TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id));
  TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id));
  TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id));
  TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis));
  TF_RETURN_IF_ERROR(context->GetAttr("columns", columns));
  TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point));
  return Status::OK();
}
46
47 } // namespace
48
49 // Note that overridden methods with names ending in "Locked" are called by
50 // ReaderBase while a mutex is held.
51 // See comments for ReaderBase.
52 class BigQueryReader : public ReaderBase {
53 public:
BigQueryReader(BigQueryTableAccessor * bigquery_table_accessor,const string & node_name)54 explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor,
55 const string& node_name)
56 : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")),
57 bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {}
58
OnWorkStartedLocked()59 Status OnWorkStartedLocked() override {
60 BigQueryTablePartition partition;
61 if (!partition.ParseFromString(current_work())) {
62 return errors::InvalidArgument(
63 "Could not parse work as valid partition.");
64 }
65 TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition));
66 return Status::OK();
67 }
68
ReadLocked(string * key,string * value,bool * produced,bool * at_end)69 Status ReadLocked(string* key, string* value, bool* produced,
70 bool* at_end) override {
71 *at_end = false;
72 *produced = false;
73 if (bigquery_table_accessor_->Done()) {
74 *at_end = true;
75 return Status::OK();
76 }
77
78 Example example;
79 int64 row_id;
80 TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example));
81
82 *key = std::to_string(row_id);
83 *value = example.SerializeAsString();
84 *produced = true;
85 return Status::OK();
86 }
87
88 private:
89 // Not owned.
90 BigQueryTableAccessor* bigquery_table_accessor_;
91 };
92
93 class BigQueryReaderOp : public ReaderOpKernel {
94 public:
BigQueryReaderOp(OpKernelConstruction * context)95 explicit BigQueryReaderOp(OpKernelConstruction* context)
96 : ReaderOpKernel(context) {
97 string table_id;
98 string project_id;
99 string dataset_id;
100 int64 timestamp_millis;
101 std::vector<string> columns;
102 string test_end_point;
103
104 OP_REQUIRES_OK(context,
105 GetTableAttrs(context, &project_id, &dataset_id, &table_id,
106 ×tamp_millis, &columns, &test_end_point));
107 OP_REQUIRES_OK(context,
108 BigQueryTableAccessor::New(
109 project_id, dataset_id, table_id, timestamp_millis,
110 kDefaultRowBufferSize, test_end_point, columns,
111 BigQueryTablePartition(), &bigquery_table_accessor_));
112
113 SetReaderFactory([this]() {
114 return new BigQueryReader(bigquery_table_accessor_.get(), name());
115 });
116 }
117
118 private:
119 std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
120 };
121
122 REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU),
123 BigQueryReaderOp);
124
125 class GenerateBigQueryReaderPartitionsOp : public OpKernel {
126 public:
GenerateBigQueryReaderPartitionsOp(OpKernelConstruction * context)127 explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context)
128 : OpKernel(context) {
129 string project_id;
130 string dataset_id;
131 string table_id;
132 int64 timestamp_millis;
133 std::vector<string> columns;
134 string test_end_point;
135
136 OP_REQUIRES_OK(context,
137 GetTableAttrs(context, &project_id, &dataset_id, &table_id,
138 ×tamp_millis, &columns, &test_end_point));
139 OP_REQUIRES_OK(context,
140 BigQueryTableAccessor::New(
141 project_id, dataset_id, table_id, timestamp_millis,
142 kDefaultRowBufferSize, test_end_point, columns,
143 BigQueryTablePartition(), &bigquery_table_accessor_));
144 OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context));
145 OP_REQUIRES_OK(context, InitializeTotalNumberOfRows());
146 }
147
Compute(OpKernelContext * context)148 void Compute(OpKernelContext* context) override {
149 const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>(
150 total_num_rows_, num_partitions_);
151 Tensor* output_tensor = nullptr;
152 OP_REQUIRES_OK(context,
153 context->allocate_output(0, TensorShape({num_partitions_}),
154 &output_tensor));
155
156 auto output = output_tensor->template flat<string>();
157 for (int64 i = 0; i < num_partitions_; ++i) {
158 BigQueryTablePartition partition;
159 partition.set_start_index(i * partition_size);
160 partition.set_end_index(
161 std::min(total_num_rows_, (i + 1) * partition_size) - 1);
162 output(i) = partition.SerializeAsString();
163 }
164 }
165
166 private:
InitializeTotalNumberOfRows()167 Status InitializeTotalNumberOfRows() {
168 total_num_rows_ = bigquery_table_accessor_->total_num_rows();
169 if (total_num_rows_ <= 0) {
170 return errors::FailedPrecondition("Invalid total number of rows.");
171 }
172 return Status::OK();
173 }
174
InitializeNumberOfPartitions(OpKernelConstruction * context)175 Status InitializeNumberOfPartitions(OpKernelConstruction* context) {
176 TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_));
177 if (num_partitions_ <= 0) {
178 return errors::FailedPrecondition("Invalid number of partitions.");
179 }
180 return Status::OK();
181 }
182
183 int64 num_partitions_;
184 int64 total_num_rows_;
185 std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
186 };
187
188 REGISTER_KERNEL_BUILDER(
189 Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU),
190 GenerateBigQueryReaderPartitionsOp);
191
192 } // namespace tensorflow
193