/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"

#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace tensorflow {
namespace {

constexpr size_t kBufferSize = 1024 * 1024;  // In bytes.
const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";

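// Returns true iff the partition has a bounded end_index that precedes its
// start_index; an end_index of -1 means "to the end of the table", so such a
// partition is never considered empty. E.g. (illustrative),
// [start_index=5, end_index=3] is empty, [start_index=5, end_index=-1] is not.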
bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
  if (partition.end_index() != -1 &&
      partition.end_index() < partition.start_index()) {
    return true;
  }
  return false;
}

Status ParseJson(StringPiece json, Json::Value* result) {
  Json::Reader reader;
  if (!reader.parse(json.ToString(), *result)) {
    return errors::Internal("Couldn't parse JSON response from BigQuery.");
  }
  return Status::OK();
}

Status ParseColumnType(const string& type,
                       BigQueryTableAccessor::ColumnType* enum_type) {
  if (type == "RECORD") {
    *enum_type = BigQueryTableAccessor::ColumnType::kRecord;
  } else if (type == "STRING") {
    *enum_type = BigQueryTableAccessor::ColumnType::kString;
  } else if (type == "BYTES") {
    *enum_type = BigQueryTableAccessor::ColumnType::kBytes;
  } else if (type == "INTEGER") {
    *enum_type = BigQueryTableAccessor::ColumnType::kInteger;
  } else if (type == "FLOAT") {
    *enum_type = BigQueryTableAccessor::ColumnType::kFloat;
  } else if (type == "BOOLEAN") {
    *enum_type = BigQueryTableAccessor::ColumnType::kBoolean;
  } else if (type == "TIMESTAMP") {
    *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp;
  } else if (type == "DATE") {
    *enum_type = BigQueryTableAccessor::ColumnType::kDate;
  } else if (type == "TIME") {
    *enum_type = BigQueryTableAccessor::ColumnType::kTime;
  } else if (type == "DATETIME") {
    *enum_type = BigQueryTableAccessor::ColumnType::kDatetime;
  } else {
    return errors::Internal(
        strings::StrCat("Could not parse column type ", type));
  }
  return Status::OK();
}

}  // namespace

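// A minimal usage sketch for the factory below (illustrative only: the
// project/dataset/table names and numeric arguments are hypothetical
// placeholders, and the caller must be in a Status-returning context):
//
//   std::unique_ptr<BigQueryTableAccessor> accessor;
//   BigQueryTablePartition partition;
//   partition.set_start_index(0);
//   partition.set_end_index(-1);  // -1 means "to the end of the table".
//   TF_RETURN_IF_ERROR(BigQueryTableAccessor::New(
//       "my-project", "my_dataset", "my_table",
//       /*timestamp_millis=*/1496018400000, /*row_buffer_size=*/1000,
//       /*end_point=*/"", /*columns=*/{}, partition, &accessor));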
Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
  return New(project_id, dataset_id, table_id, timestamp_millis,
             row_buffer_size, end_point, columns, partition, nullptr, nullptr,
             accessor);
}

Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
  if (timestamp_millis <= 0) {
    return errors::InvalidArgument(
        "Cannot use a zero or negative timestamp to query a table.");
  }
  const string& big_query_end_point =
      end_point.empty() ? kBigQueryEndPoint : end_point;
  if (auth_provider == nullptr && http_request_factory == nullptr) {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition));
  } else {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition, std::move(auth_provider),
        std::move(http_request_factory)));
  }
  return (*accessor)->ReadSchema();
}

BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition)
    : BigQueryTableAccessor(
          project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
          end_point, columns, partition,
          std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
          std::unique_ptr<HttpRequest::Factory>(
              new CurlHttpRequest::Factory())) {
  row_buffer_.resize(row_buffer_size);
}

BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory)
    : project_id_(project_id),
      dataset_id_(dataset_id),
      table_id_(table_id),
      timestamp_millis_(timestamp_millis),
      columns_(columns.begin(), columns.end()),
      bigquery_end_point_(end_point),
      partition_(partition),
      auth_provider_(std::move(auth_provider)),
      http_request_factory_(std::move(http_request_factory)) {
  row_buffer_.resize(row_buffer_size);
  Reset();
}

Status BigQueryTableAccessor::SetPartition(
    const BigQueryTablePartition& partition) {
  if (partition.start_index() < 0) {
    return errors::InvalidArgument("Start index cannot be negative.");
  }
  partition_ = partition;
  Reset();
  return Status::OK();
}

void BigQueryTableAccessor::Reset() {
  first_buffered_row_index_ = partition_.start_index();
  next_row_in_buffer_ = -1;
  next_page_token_ = "";
}

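// Reads the next row of the table into *example and its row index into
// *row_id. A typical caller-side loop, sketched with hypothetical names:
//
//   int64 row_id;
//   Example example;
//   while (!accessor->Done()) {
//     TF_RETURN_IF_ERROR(accessor->ReadRow(&row_id, &example));
//     // ... consume `example` ...
//   }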
Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
  if (Done()) {
    return errors::OutOfRange("Reached end of table ", FullTableName());
  }

  // If the next row is already fetched and cached, return the row from the
  // buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
  if (next_row_in_buffer_ != -1 &&
      next_row_in_buffer_ < ComputeMaxResultsArg()) {
    *row_id = first_buffered_row_index_ + next_row_in_buffer_;
    *example = row_buffer_[next_row_in_buffer_];
    next_row_in_buffer_++;
  } else {
    string auth_token;
    TF_RETURN_IF_ERROR(
        AuthProvider::GetToken(auth_provider_.get(), &auth_token));

    std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
    std::vector<char> output_buffer;
    output_buffer.reserve(kBufferSize);

    // There is no page token on the first request to BigQuery; subsequent
    // requests use the page token, which returns rows faster.
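    // Illustrative URIs built by the two branches below (table path,
    // maxResults value, and page token are hypothetical):
    //   .../tables/my_table/data?maxResults=1000&startIndex=0
    //   .../tables/my_table/data?maxResults=1000&pageToken=BEV5...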
    if (!next_page_token_.empty()) {
      request->SetUri(strings::StrCat(
          BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
          "&pageToken=", request->EscapeString(next_page_token_)));
      first_buffered_row_index_ += row_buffer_.size();
    } else {
      request->SetUri(strings::StrCat(
          BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
          "&startIndex=", first_buffered_row_index_));
    }
    request->AddAuthBearerHeader(auth_token);
    request->SetResultBuffer(&output_buffer);
    TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ",
                                    FullTableName());

    // Parse the returned rows.
    StringPiece response_piece =
        StringPiece(&output_buffer[0], output_buffer.size());
    Json::Value root;
    TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
    for (unsigned int i = 0; i < root["rows"].size(); ++i) {
      row_buffer_[i].Clear();
      TF_RETURN_IF_ERROR(
          ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i]));
    }

    next_page_token_ = root["pageToken"].asString();
    *row_id = first_buffered_row_index_;
    *example = row_buffer_[0];
    next_row_in_buffer_ = 1;
  }
  return Status::OK();
}

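// Computes the maxResults argument for the next data request: the full row
// buffer size when the partition is unbounded, otherwise the buffer size
// capped at the rows remaining in the partition. Illustrative example: a
// 1000-row buffer with partition [start_index=10, end_index=14] yields
// min(1000, 14 - 10 + 1) == 5.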
int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
  if (partition_.end_index() == -1) {
    return row_buffer_.size();
  }
  if (IsPartitionEmpty(partition_)) {
    return 0;
  }
  return std::min(static_cast<int64>(row_buffer_.size()),
                  static_cast<int64>(partition_.end_index() -
                                     partition_.start_index() + 1));
}

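// Recursively walks one row of the tabledata JSON, where each row arrives as
// {"f": [{"v": ...}, ...]} with one "v" entry per schema column (and nested
// "f"/"v" objects for RECORD columns), appending requested columns to the
// Example. For instance, a row of a (name, age) table might arrive as
// {"f": [{"v": "alice"}, {"v": "30"}]} (values illustrative).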
Status BigQueryTableAccessor::ParseColumnValues(
    const Json::Value& value, const SchemaNode& root_schema_node,
    Example* example) {
  if (value.empty()) {
    return Status::OK();
  }
  if (value["f"].isNull()) {
    return Status::OK();
  }
  int value_index = 0;
  for (const auto& schema_node : root_schema_node.schema_nodes) {
    if (value["f"][value_index].isNull()) {
      value_index++;
      continue;
    }

    if (schema_node.type == ColumnType::kRecord) {
      TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"],
                                           schema_node, example));
    } else {
      // Append the column value only if the user has requested the column.
      if (columns_.empty() ||
          columns_.find(schema_node.name) != columns_.end()) {
        TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name,
                                                value["f"][value_index]["v"],
                                                schema_node.type, example));
      }
    }
    value_index++;
  }
  return Status::OK();
}

Status BigQueryTableAccessor::ReadSchema() {
  string auth_token;
  TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));

  // Send a request to read the schema.
  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
  std::vector<char> output_buffer;
  output_buffer.reserve(kBufferSize);
  request->SetUri(BigQueryUriPrefix());
  request->AddAuthBearerHeader(auth_token);
  request->SetResultBuffer(&output_buffer);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ",
                                  FullTableName());

  // Parse the schema.
  StringPiece response_piece =
      StringPiece(&output_buffer[0], output_buffer.size());

  Json::Value root;
  TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
  const auto& columns = root["schema"]["fields"];
  string column_name_prefix = "";
  schema_root_ = {"", ColumnType::kNone};
  TF_RETURN_IF_ERROR(
      ExtractColumnType(columns, column_name_prefix, &schema_root_));
  if (root["numRows"].isNull()) {
    return errors::Internal("Number of rows cannot be extracted for table ",
                            FullTableName());
  }
  strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_);
  return Status::OK();
}

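// Recursively flattens the schema's "fields" array into schema_nodes,
// joining nested RECORD field names with dots; e.g. (illustrative) a RECORD
// column "address" containing a field "city" produces the flattened column
// name "address.city".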
Status BigQueryTableAccessor::ExtractColumnType(
    const Json::Value& columns, const string& column_name_prefix,
    SchemaNode* root) {
  for (auto columns_it = columns.begin(); columns_it != columns.end();
       ++columns_it) {
    if ((*columns_it)["mode"].asString() == "REPEATED") {
      return errors::Unimplemented(strings::StrCat(
          "Tables with repeated columns are not supported: ", FullTableName()));
    }
    ColumnType type;
    const string current_column_name = strings::StrCat(
        column_name_prefix, (*columns_it)["name"].asString().c_str());
    TF_RETURN_IF_ERROR(
        ParseColumnType((*columns_it)["type"].asString().c_str(), &type));
    root->schema_nodes.emplace_back(current_column_name, type);
    if (type == ColumnType::kRecord) {
      const auto new_prefix = strings::StrCat(current_column_name, ".");
      TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix,
                                           &root->schema_nodes.back()));
    }
  }
  return Status::OK();
}

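// Appends a single column value to the Example feature named column_name.
// Per the switch below, string-like types (STRING, BYTES, TIMESTAMP, DATE,
// TIME, DATETIME) go to the bytes_list, BOOLEAN and INTEGER to the
// int64_list, and FLOAT to the float_list.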
Status BigQueryTableAccessor::AppendValueToExample(
    const string& column_name, const Json::Value& column_value,
    const BigQueryTableAccessor::ColumnType type, Example* example) {
  if (column_value.isNull()) {
    return Status::OK();
  }
  auto& feature =
      (*example->mutable_features()->mutable_feature())[column_name];

  switch (type) {
    case BigQueryTableAccessor::ColumnType::kNone:
    case BigQueryTableAccessor::ColumnType::kRecord:
      return errors::Unimplemented(
          "Cannot append RECORD or NONE type to an example.");
    case BigQueryTableAccessor::ColumnType::kTimestamp:
    case BigQueryTableAccessor::ColumnType::kDate:
    case BigQueryTableAccessor::ColumnType::kTime:
    case BigQueryTableAccessor::ColumnType::kDatetime:
    case BigQueryTableAccessor::ColumnType::kString:
    case BigQueryTableAccessor::ColumnType::kBytes:
      feature.mutable_bytes_list()->add_value(column_value.asString());
      break;
    case BigQueryTableAccessor::ColumnType::kBoolean:
      feature.mutable_int64_list()->add_value(
          column_value.asString() == "false" ? 0 : 1);
      break;
    case BigQueryTableAccessor::ColumnType::kInteger:
      int64 column_value_int64;
      if (!strings::safe_strto64(column_value.asString().c_str(),
                                 &column_value_int64)) {
        return errors::Internal("Cannot convert value to integer: ",
                                column_value.asString().c_str());
      }
      feature.mutable_int64_list()->add_value(column_value_int64);
      break;
    case BigQueryTableAccessor::ColumnType::kFloat:
      // BigQuery FLOAT is actually a double.
      double column_value_double;
      if (!strings::safe_strtod(column_value.asString().c_str(),
                                &column_value_double)) {
        return errors::Internal("Cannot convert value to double: ",
                                column_value.asString().c_str());
      }
      feature.mutable_float_list()->add_value(
          static_cast<float>(column_value_double));
      break;
  }
  return Status::OK();
}

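// Builds the URI prefix of this table's REST resource; with hypothetical
// names and the default endpoint this yields, e.g.:
//   https://www.googleapis.com/bigquery/v2/projects/my-project/
//       datasets/my_dataset/tables/my_table/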
string BigQueryTableAccessor::BigQueryUriPrefix() {
  CurlHttpRequest request;
  return strings::StrCat(bigquery_end_point_, "/projects/",
                         request.EscapeString(project_id_), "/datasets/",
                         request.EscapeString(dataset_id_), "/tables/",
                         request.EscapeString(table_id_), "/");
}

bool BigQueryTableAccessor::Done() {
  return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
         IsPartitionEmpty(partition_) ||
         (partition_.end_index() != -1 &&
          partition_.end_index() <
              first_buffered_row_index_ + next_row_in_buffer_);
}

}  // namespace tensorflow