/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"

#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace {

constexpr size_t kBufferSize = 1024 * 1024;  // In bytes.
const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";

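// Returns true if the partition cannot contain any rows: its end index is
// bounded (not -1) and lies before its start index.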
bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
  return partition.end_index() != -1 &&
         partition.end_index() < partition.start_index();
}

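// Parses `json` into `result`, returning an Internal error if the payload is
// not valid JSON.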
Status ParseJson(StringPiece json, Json::Value* result) {
  Json::Reader reader;
  if (!reader.parse(json.ToString(), *result)) {
    return errors::Internal("Couldn't parse JSON response from BigQuery.");
  }
  return Status::OK();
}

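// Maps a BigQuery schema type string (e.g. "STRING", "INTEGER") to the
// corresponding ColumnType enum value.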
Status ParseColumnType(const string& type,
                       BigQueryTableAccessor::ColumnType* enum_type) {
  if (type == "RECORD") {
    *enum_type = BigQueryTableAccessor::ColumnType::kRecord;
  } else if (type == "STRING") {
    *enum_type = BigQueryTableAccessor::ColumnType::kString;
  } else if (type == "BYTES") {
    *enum_type = BigQueryTableAccessor::ColumnType::kBytes;
  } else if (type == "INTEGER") {
    *enum_type = BigQueryTableAccessor::ColumnType::kInteger;
  } else if (type == "FLOAT") {
    *enum_type = BigQueryTableAccessor::ColumnType::kFloat;
  } else if (type == "BOOLEAN") {
    *enum_type = BigQueryTableAccessor::ColumnType::kBoolean;
  } else if (type == "TIMESTAMP") {
    *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp;
  } else if (type == "DATE") {
    *enum_type = BigQueryTableAccessor::ColumnType::kDate;
  } else if (type == "TIME") {
    *enum_type = BigQueryTableAccessor::ColumnType::kTime;
  } else if (type == "DATETIME") {
    *enum_type = BigQueryTableAccessor::ColumnType::kDatetime;
  } else {
    return errors::Internal(
        strings::StrCat("Could not parse column type ", type));
  }
  return Status::OK();
}

}  // namespace

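// Creates a BigQueryTableAccessor that authenticates with the default
// GoogleAuthProvider and issues requests through CurlHttpRequest.
//
// A minimal usage sketch; the project, dataset, and table names below are
// placeholders:
//
//   std::unique_ptr<BigQueryTableAccessor> accessor;
//   BigQueryTablePartition partition;
//   partition.set_start_index(0);
//   partition.set_end_index(-1);  // -1 reads to the end of the table.
//   TF_RETURN_IF_ERROR(BigQueryTableAccessor::New(
//       "my-project", "my-dataset", "my-table",
//       /*timestamp_millis=*/1234567890000, /*row_buffer_size=*/1000,
//       /*end_point=*/"", /*columns=*/{}, partition, &accessor));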
Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
  return New(project_id, dataset_id, table_id, timestamp_millis,
             row_buffer_size, end_point, columns, partition, nullptr, nullptr,
             accessor);
}

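// Overload that allows injecting an AuthProvider and an HttpRequest::Factory,
// primarily for testing. Passing nullptr for both falls back to the default
// implementations.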
Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
  if (timestamp_millis <= 0) {
    return errors::InvalidArgument(
        "Cannot use zero or negative timestamp to query a table.");
  }
  const string& big_query_end_point =
      end_point.empty() ? kBigQueryEndPoint : end_point;
  if (auth_provider == nullptr && http_request_factory == nullptr) {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition));
  } else {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition, std::move(auth_provider),
        std::move(http_request_factory)));
  }
  return (*accessor)->ReadSchema();
}

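// Delegates to the dependency-injecting constructor with the production
// GoogleAuthProvider and CurlHttpRequest factory.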
BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition)
    : BigQueryTableAccessor(
          project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
          end_point, columns, partition,
          std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
          std::unique_ptr<HttpRequest::Factory>(
              new CurlHttpRequest::Factory())) {}

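// Stores the table coordinates and injected dependencies, sizes the row
// buffer, and positions the read cursor at the start of the partition.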
BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory)
    : project_id_(project_id),
      dataset_id_(dataset_id),
      table_id_(table_id),
      timestamp_millis_(timestamp_millis),
      columns_(columns.begin(), columns.end()),
      bigquery_end_point_(end_point),
      partition_(partition),
      auth_provider_(std::move(auth_provider)),
      http_request_factory_(std::move(http_request_factory)) {
  row_buffer_.resize(row_buffer_size);
  Reset();
}

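// Replaces the current partition and rewinds the read position to its start.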
Status BigQueryTableAccessor::SetPartition(
    const BigQueryTablePartition& partition) {
  if (partition.start_index() < 0) {
    return errors::InvalidArgument("Start index cannot be negative.");
  }
  partition_ = partition;
  Reset();
  return Status::OK();
}

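// Rewinds the cursor to the first row of the partition, invalidates the
// buffered rows, and clears the page token.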
void BigQueryTableAccessor::Reset() {
  first_buffered_row_index_ = partition_.start_index();
  next_row_in_buffer_ = -1;
  next_page_token_ = "";
}

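// Returns the next row of the partition in `*example` and its table-wide index
// in `*row_id`, refilling the row buffer from BigQuery once the cached rows
// are exhausted. Returns OutOfRange after the last row has been read.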
Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
  if (Done()) {
    return errors::OutOfRange("Reached end of table ", FullTableName());
  }

  // If the next row is already fetched and cached, return the row from the
  // buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
  if (next_row_in_buffer_ != -1 &&
      next_row_in_buffer_ < ComputeMaxResultsArg()) {
    *row_id = first_buffered_row_index_ + next_row_in_buffer_;
    *example = row_buffer_[next_row_in_buffer_];
    next_row_in_buffer_++;
  } else {
    string auth_token;
    TF_RETURN_IF_ERROR(
        AuthProvider::GetToken(auth_provider_.get(), &auth_token));

    std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
    std::vector<char> output_buffer;
    output_buffer.reserve(kBufferSize);

    // The first request has no page token. Subsequent requests pass the page
    // token from the previous response, which returns rows faster than a
    // startIndex lookup.
    if (!next_page_token_.empty()) {
      request->SetUri(strings::StrCat(
          BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
          "&pageToken=", request->EscapeString(next_page_token_)));
      first_buffered_row_index_ += row_buffer_.size();
    } else {
      request->SetUri(strings::StrCat(
          BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
          "&startIndex=", first_buffered_row_index_));
    }
    request->AddAuthBearerHeader(auth_token);
    request->SetResultBuffer(&output_buffer);
    TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ",
                                    FullTableName());

    // Parse the returned rows into the row buffer.
    StringPiece response_piece =
        StringPiece(output_buffer.data(), output_buffer.size());
    Json::Value root;
    TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
    for (unsigned int i = 0; i < root["rows"].size(); ++i) {
      row_buffer_[i].Clear();
      TF_RETURN_IF_ERROR(
          ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i]));
    }

    next_page_token_ = root["pageToken"].asString();
    *row_id = first_buffered_row_index_;
    *example = row_buffer_[0];
    next_row_in_buffer_ = 1;
  }
  return Status::OK();
}

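// Computes the maxResults parameter for the next fetch: the row buffer size,
// clamped to the total size of a bounded partition.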
int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
  if (partition_.end_index() == -1) {
    return row_buffer_.size();
  }
  if (IsPartitionEmpty(partition_)) {
    return 0;
  }
  return std::min(static_cast<int64>(row_buffer_.size()),
                  static_cast<int64>(partition_.end_index() -
                                     partition_.start_index() + 1));
}

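// Walks the schema tree and the row's "f" (fields) array in lockstep,
// recursing into RECORD columns and appending each requested leaf value to
// `example`.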
Status BigQueryTableAccessor::ParseColumnValues(
    const Json::Value& value, const SchemaNode& root_schema_node,
    Example* example) {
  if (value.empty()) {
    return Status::OK();
  }
  if (value["f"].isNull()) {
    return Status::OK();
  }
  int value_index = 0;
  for (const auto& schema_node : root_schema_node.schema_nodes) {
    if (value["f"][value_index].isNull()) {
      value_index++;
      continue;
    }

    if (schema_node.type == ColumnType::kRecord) {
      TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"],
                                           schema_node, example));
    } else {
      // Append the column value only if the user has requested the column.
      if (columns_.empty() ||
          columns_.find(schema_node.name) != columns_.end()) {
        TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name,
                                                value["f"][value_index]["v"],
                                                schema_node.type, example));
      }
    }
    value_index++;
  }
  return Status::OK();
}

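// Fetches the table metadata, builds the schema tree used to decode rows, and
// records the total number of rows in the table.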
Status BigQueryTableAccessor::ReadSchema() {
  string auth_token;
  TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));

  // Send a request to read the schema.
  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
  std::vector<char> output_buffer;
  output_buffer.reserve(kBufferSize);
  request->SetUri(BigQueryUriPrefix());
  request->AddAuthBearerHeader(auth_token);
  request->SetResultBuffer(&output_buffer);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ",
                                  FullTableName());

  // Parse the schema.
  StringPiece response_piece =
      StringPiece(output_buffer.data(), output_buffer.size());

  Json::Value root;
  TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
  const auto& columns = root["schema"]["fields"];
  string column_name_prefix = "";
  schema_root_ = {"", ColumnType::kNone};
  TF_RETURN_IF_ERROR(
      ExtractColumnType(columns, column_name_prefix, &schema_root_));
  if (root["numRows"].isNull()) {
    return errors::Internal("Number of rows cannot be extracted for table ",
                            FullTableName());
  }
  if (!strings::safe_strto64(root["numRows"].asString().c_str(),
                             &total_num_rows_)) {
    return errors::Internal("Cannot parse the number of rows for table ",
                            FullTableName());
  }
  return Status::OK();
}

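// Recursively converts the JSON "fields" array of a schema into SchemaNodes,
// prefixing nested column names with their parent record's name, e.g.
// "address.city". Tables with REPEATED columns are rejected.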
Status BigQueryTableAccessor::ExtractColumnType(
    const Json::Value& columns, const string& column_name_prefix,
    SchemaNode* root) {
  for (auto columns_it = columns.begin(); columns_it != columns.end();
       ++columns_it) {
    if ((*columns_it)["mode"].asString() == "REPEATED") {
      return errors::Unimplemented(strings::StrCat(
          "Tables with repeated columns are not supported: ", FullTableName()));
    }
    ColumnType type;
    const string current_column_name = strings::StrCat(
        column_name_prefix, (*columns_it)["name"].asString());
    TF_RETURN_IF_ERROR(
        ParseColumnType((*columns_it)["type"].asString(), &type));
    root->schema_nodes.emplace_back(current_column_name, type);
    if (type == ColumnType::kRecord) {
      const auto new_prefix = strings::StrCat(current_column_name, ".");
      TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix,
                                           &root->schema_nodes.back()));
    }
  }
  return Status::OK();
}

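// Appends a single column value to `example` under the feature key
// `column_name`, converted to the bytes, int64, or float list that matches the
// column type.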
Status BigQueryTableAccessor::AppendValueToExample(
    const string& column_name, const Json::Value& column_value,
    const BigQueryTableAccessor::ColumnType type, Example* example) {
  if (column_value.isNull()) {
    return Status::OK();
  }
  auto& feature =
      (*example->mutable_features()->mutable_feature())[column_name];

  switch (type) {
    case BigQueryTableAccessor::ColumnType::kNone:
    case BigQueryTableAccessor::ColumnType::kRecord:
      return errors::Unimplemented(
          "Cannot append a none or record type to an example.");
    case BigQueryTableAccessor::ColumnType::kTimestamp:
    case BigQueryTableAccessor::ColumnType::kDate:
    case BigQueryTableAccessor::ColumnType::kTime:
    case BigQueryTableAccessor::ColumnType::kDatetime:
    case BigQueryTableAccessor::ColumnType::kString:
    case BigQueryTableAccessor::ColumnType::kBytes:
      feature.mutable_bytes_list()->add_value(column_value.asString());
      break;
    case BigQueryTableAccessor::ColumnType::kBoolean:
      feature.mutable_int64_list()->add_value(
          column_value.asString() == "false" ? 0 : 1);
      break;
    case BigQueryTableAccessor::ColumnType::kInteger:
      int64 column_value_int64;
      if (!strings::safe_strto64(column_value.asString().c_str(),
                                 &column_value_int64)) {
        return errors::Internal("Cannot convert value to integer: ",
                                column_value.asString());
      }
      feature.mutable_int64_list()->add_value(column_value_int64);
      break;
    case BigQueryTableAccessor::ColumnType::kFloat:
      // BigQuery float is actually a double.
      double column_value_double;
      if (!strings::safe_strtod(column_value.asString().c_str(),
                                &column_value_double)) {
        return errors::Internal("Cannot convert value to double: ",
                                column_value.asString());
      }
      feature.mutable_float_list()->add_value(
          static_cast<float>(column_value_double));
      break;
  }
  return Status::OK();
}

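// Builds the REST resource prefix for this table:
// <end_point>/projects/<project>/datasets/<dataset>/tables/<table>/.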
string BigQueryTableAccessor::BigQueryUriPrefix() {
  CurlHttpRequest request;
  return strings::StrCat(bigquery_end_point_, "/projects/",
                         request.EscapeString(project_id_), "/datasets/",
                         request.EscapeString(dataset_id_), "/tables/",
                         request.EscapeString(table_id_), "/");
}

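// True once the cursor has moved past the last row of the table or of a
// bounded partition, or when the partition is empty.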
bool BigQueryTableAccessor::Done() {
  return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
         IsPartitionEmpty(partition_) ||
         (partition_.end_index() != -1 &&
          partition_.end_index() <
              first_buffered_row_index_ + next_row_in_buffer_);
}

}  // namespace tensorflow