• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/computation_builder.h"
17 
18 #include <stddef.h>
19 #include <array>
20 #include <numeric>
21 #include <set>
22 #include <vector>
23 
24 #include "tensorflow/compiler/xla/ptr_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 
35 namespace xla {
36 
// Constructs a builder that records operations into a computation named
// `computation_name`, issuing service RPCs through `client`.
ComputationBuilder::ComputationBuilder(Client* client,
                                       const string& computation_name)
    : name_(computation_name), client_(client) {}
40 
ComputationBuilder::~ComputationBuilder() {}
42 
// Records `error` as this builder's sticky first error.  When
// die_immediately_on_error_ is set, crashes instead so the failing stack is
// visible at the point of error.  Only the first error is retained (later
// ones are dropped), along with a backtrace for diagnostics.
void ComputationBuilder::NoteError(const Status& error) {
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    // skip_count=1 omits NoteError itself from the captured backtrace.
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
}
53 
CreateSubBuilder(const string & computation_name)54 std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder(
55     const string& computation_name) {
56   auto sub_builder = MakeUnique<ComputationBuilder>(client_, computation_name);
57   sub_builder->parent_builder_ = this;
58   sub_builder->die_immediately_on_error_ = die_immediately_on_error_;
59   return sub_builder;
60 }
61 
PrepareComputation()62 Status ComputationBuilder::PrepareComputation() {
63   TF_RETURN_IF_ERROR(first_error_);
64 
65   if (!computation_.IsNull()) {
66     return Status::OK();
67   }
68 
69   ComputationRequest request;
70   request.set_name(name_);
71   ComputationResponse response;
72 
73   VLOG(2) << "making computation request";
74   Status s = client_->stub()->Computation(&request, &response);
75   VLOG(2) << "done with computation request";
76 
77   if (!s.ok()) {
78     NoteError(s);
79     return first_error_;
80   }
81 
82   computation_ = Computation(client_->stub(), response.computation());
83   return Status::OK();
84 }
85 
// Dispatches `op_request` to the service, filling `op_response`.  Stamps the
// request with the fields common to every op (computation handle, metadata,
// and sharding when active) before dispatch.  Any previously recorded error
// short-circuits the call.
Status ComputationBuilder::RunOp(OpRequest* op_request,
                                 OpResponse* op_response) {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(PrepareComputation());

  // Fill in fields that are set on every OpRequest.
  *op_request->mutable_computation() = computation_.handle();
  *op_request->mutable_metadata() = metadata_;
  if (sharding_) {
    *op_request->mutable_sharding() = *sharding_;
  }

  // Resolve the human-readable name of the populated oneof field (the op
  // kind) via protobuf reflection, purely for logging.
  const string& op_name =
      OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name();
  VLOG(2) << "running op request: " << op_name;
  Status status = client_->stub()->Op(op_request, op_response);
  VLOG(2) << "done with op request: " << op_name;
  return status;
}
105 
RunOpAndNoteError(OpRequest * op_request)106 void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) {
107   OpResponse op_response;
108   Status status = RunOp(op_request, &op_response);
109   if (!status.ok()) {
110     NoteError(status);
111   }
112 }
113 
RunOpAndParseResponse(OpRequest * op_request)114 ComputationDataHandle ComputationBuilder::RunOpAndParseResponse(
115     OpRequest* op_request) {
116   OpResponse op_response;
117   Status status = RunOp(op_request, &op_response);
118   if (!status.ok()) {
119     NoteError(status);
120     return ComputationDataHandle();
121   }
122   if (op_response.output().handle() == 0) {
123     NoteError(InternalError("No output handle"));
124     return ComputationDataHandle();
125   }
126   return op_response.output();
127 }
128 
// Populates `window` from per-dimension convolution parameters.  Each of the
// auxiliary slices (strides, padding, dilations) may either be empty — in
// which case the documented default is used for every dimension (stride 1,
// padding 0/0, dilation 1) — or must have exactly one entry per window
// dimension.  On a size mismatch an error is noted and false is returned;
// returns true on success.
bool ComputationBuilder::MakeWindow(
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
  // Accepts a slice that is either empty (defaulted) or matches
  // window_dimensions in length; notes an error otherwise.
  const auto verify_size = [&](const size_t x, const char* x_name) {
    if (x == 0 || x == window_dimensions.size()) {
      return true;
    } else {
      NoteError(InvalidArgument(
          "%s", tensorflow::strings::StrCat(
                    "Window has different number of window dimensions than of ",
                    x_name, "\nNumber of window dimensions: ",
                    window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
                    "\n")
                    .c_str()));  //
      return false;
    }
  };
  if (!verify_size(window_strides.size(), "window strides") ||
      !verify_size(padding.size(), "padding entries") ||
      !verify_size(lhs_dilation.size(), "lhs dilation factors") ||
      !verify_size(rhs_dilation.size(), "rhs dilation factors")) {
    return false;
  }

  window->Clear();
  for (size_t i = 0; i < window_dimensions.size(); i++) {
    auto dim = window->add_dimensions();
    dim->set_size(window_dimensions[i]);
    // Empty slices mean "use the identity default" for every dimension.
    if (!window_strides.empty()) {
      dim->set_stride(window_strides[i]);
    } else {
      dim->set_stride(1);
    }
    if (!padding.empty()) {
      dim->set_padding_low(padding[i].first);
      dim->set_padding_high(padding[i].second);
    } else {
      dim->set_padding_low(0);
      dim->set_padding_high(0);
    }
    if (!lhs_dilation.empty()) {
      dim->set_base_dilation(lhs_dilation[i]);
    } else {
      dim->set_base_dilation(1);
    }
    if (!rhs_dilation.empty()) {
      dim->set_window_dilation(rhs_dilation[i]);
    } else {
      dim->set_window_dilation(1);
    }
    dim->set_window_reversal(false);
  }
  return true;
}
186 
ConstantLiteral(const Literal & literal)187 ComputationDataHandle ComputationBuilder::ConstantLiteral(
188     const Literal& literal) {
189   OpRequest op_request;
190   ConstantRequest* request = op_request.mutable_constant_request();
191   *request->mutable_literal() = literal.ToProto();
192   VLOG(3) << "created constant: " << request->literal().ShortDebugString();
193   return RunOpAndParseResponse(&op_request);
194 }
195 
Parameter(int64 parameter_number,const Shape & shape,const string & name)196 ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
197                                                     const Shape& shape,
198                                                     const string& name) {
199   OpRequest op_request;
200   ParameterRequest* request = op_request.mutable_parameter_request();
201   *request->mutable_shape() = shape;
202   request->set_parameter(parameter_number);
203   request->set_name(name);
204   return RunOpAndParseResponse(&op_request);
205 }
206 
// Fetches the shape of `operand` from the service without recording any
// failure on the builder — callers decide whether to NoteError.  Used so
// that GetShape can centralize error latching.
StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShapeWithoutNoteError(
    const ComputationDataHandle& operand) {
  GetLocalShapeRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  GetLocalShapeResponse response;

  VLOG(2) << "making get-shape request";
  TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response));
  VLOG(2) << "done with request";

  TF_RET_CHECK(response.has_shape());
  // Take ownership of the shape released from the response proto.
  std::unique_ptr<Shape> shape = WrapUnique(response.release_shape());
  TF_RET_CHECK(shape != nullptr);
  return std::move(shape);
}
223 
// Returns the shape of `operand`.  On failure the error is noted on the
// builder and the (now latched) first error is returned — note this returns
// first_error_, not necessarily the status of this particular call, so the
// earliest recorded failure wins.
StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
    const ComputationDataHandle& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  auto status_or_shape = GetShapeWithoutNoteError(operand);
  if (!status_or_shape.ok()) {
    NoteError(status_or_shape.status());
    return first_error_;
  }
  return status_or_shape;
}
235 
GetProgramShape()236 StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
237   TF_RETURN_IF_ERROR(first_error_);
238 
239   GetComputationShapeRequest request;
240   *request.mutable_computation() = computation_.handle();
241   GetComputationShapeResponse response;
242 
243   VLOG(2) << "making get-program-shape-request";
244   Status status = client_->stub()->GetComputationShape(&request, &response);
245   VLOG(2) << "done with get-program-shape-request";
246 
247   if (!status.ok()) {
248     first_error_ = status;
249     return status;
250   }
251 
252   TF_RET_CHECK(response.has_program_shape());
253   return std::move(*response.mutable_program_shape());
254 }
255 
// CHECK-fails unless `operand` has exactly `expected_shape`.  Returns the
// operand unchanged so the check can be chained inline in expressions.
ComputationDataHandle ComputationBuilder::CheckShape(
    const ComputationDataHandle& operand, const Shape& expected_shape) {
  std::unique_ptr<Shape> actual = GetShape(operand).ConsumeValueOrDie();
  CHECK(ShapeUtil::Equal(expected_shape, *actual))
      << "want " << ShapeUtil::HumanString(expected_shape) << " got "
      << ShapeUtil::HumanString(*actual);
  return operand;
}
264 
CheckSameShape(const ComputationDataHandle & lhs,const ComputationDataHandle & rhs)265 void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
266                                         const ComputationDataHandle& rhs) {
267   std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie();
268   std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie();
269   VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals "
270           << ShapeUtil::HumanString(*rhs_shape);
271   CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
272       << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
273       << ShapeUtil::HumanString(*rhs_shape);
274 }
275 
// Enqueues a strided slice of `operand`: per dimension, takes elements in
// [start, limit) stepping by the corresponding stride.
ComputationDataHandle ComputationBuilder::Slice(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  OpRequest op_request;
  SliceRequest* slice = op_request.mutable_slice_request();
  *slice->mutable_operand() = operand;
  for (int64 start : start_indices) {
    slice->add_start_indices(start);
  }
  for (int64 limit : limit_indices) {
    slice->add_limit_indices(limit);
  }
  for (int64 stride : strides) {
    slice->add_strides(stride);
  }
  return RunOpAndParseResponse(&op_request);
}
295 
// Slices only dimension `dimno` of `operand` to [start_index, limit_index)
// with the given stride, keeping every other dimension whole.
ComputationDataHandle ComputationBuilder::SliceInDim(
    const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
    int64 stride, int64 dimno) {
  StatusOr<std::unique_ptr<Shape>> shape_or = GetShape(operand);
  if (!shape_or.ok()) {
    NoteError(shape_or.status());
    return ComputationDataHandle{};
  }
  const Shape& shape = *shape_or.ValueOrDie();
  // Start from the identity slice over the whole shape, then narrow the
  // requested dimension.
  std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
  std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
  std::vector<int64> limits(shape.dimensions().begin(),
                            shape.dimensions().end());
  starts[dimno] = start_index;
  limits[dimno] = limit_index;
  strides[dimno] = stride;
  return Slice(operand, starts, limits, strides);
}
314 
// Enqueues a dynamic slice of `operand` whose start position is itself a
// computed value (`start_indices`), with static `slice_sizes`.
ComputationDataHandle ComputationBuilder::DynamicSlice(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  OpRequest op_request;
  DynamicSliceRequest* slice = op_request.mutable_dynamic_slice_request();
  *slice->mutable_operand() = operand;
  *slice->mutable_start_indices() = start_indices;
  for (int64 size : slice_sizes) {
    slice->add_slice_sizes(size);
  }
  return RunOpAndParseResponse(&op_request);
}
328 
// Enqueues an op that overwrites the region of `operand` starting at the
// computed `start_indices` with the contents of `update`.
ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
    const ComputationDataHandle& operand, const ComputationDataHandle& update,
    const ComputationDataHandle& start_indices) {
  OpRequest op_request;
  DynamicUpdateSliceRequest* dus =
      op_request.mutable_dynamic_update_slice_request();
  *dus->mutable_operand() = operand;
  *dus->mutable_update() = update;
  *dus->mutable_start_indices() = start_indices;
  return RunOpAndParseResponse(&op_request);
}
340 
ConcatInDim(tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,int64 dimension)341 ComputationDataHandle ComputationBuilder::ConcatInDim(
342     tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
343     int64 dimension) {
344   OpRequest op_request;
345   ConcatenateRequest* request = op_request.mutable_concatenate_request();
346   for (const ComputationDataHandle& operand : operands) {
347     *request->add_operands() = operand;
348   }
349   request->set_dimension(dimension);
350   return RunOpAndParseResponse(&op_request);
351 }
352 
// Enqueues a broadcast that prepends the dimensions in `broadcast_sizes`
// to the shape of `operand`.
ComputationDataHandle ComputationBuilder::Broadcast(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  OpRequest op_request;
  BroadcastRequest* broadcast = op_request.mutable_broadcast_request();
  *broadcast->mutable_operand() = operand;
  for (int64 dimension_size : broadcast_sizes) {
    broadcast->add_broadcast_sizes(dimension_size);
  }
  return RunOpAndParseResponse(&op_request);
}
364 
// Enqueues a pad of `operand` with `padding_value`, using the low/high/
// interior amounts described by `padding_config`.
ComputationDataHandle ComputationBuilder::Pad(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& padding_value,
    const PaddingConfig& padding_config) {
  OpRequest op_request;
  PadRequest* pad = op_request.mutable_pad_request();
  *pad->mutable_padding_config() = padding_config;
  *pad->mutable_padding_value() = padding_value;
  *pad->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}
376 
// Enqueues a reshape that first permutes the operand's dimensions by
// `dimensions` and then reinterprets the result with sizes `new_sizes`.
ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  OpRequest op_request;
  ReshapeRequest* reshape = op_request.mutable_reshape_request();
  *reshape->mutable_operand() = operand;
  for (int64 dim : dimensions) {
    reshape->add_dimensions(dim);
  }
  for (int64 size : new_sizes) {
    reshape->add_new_sizes(size);
  }
  return RunOpAndParseResponse(&op_request);
}
392 
// Convenience overload: reshapes `operand` to `new_sizes` with dimensions
// taken in their existing (identity) order.
ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_or = GetShape(operand);
  if (!shape_or.ok()) {
    return ComputationDataHandle();
  }
  // Build the identity permutation [0, 1, ..., rank-1].
  const auto rank = shape_or.ValueOrDie()->dimensions().size();
  std::vector<int64> dimensions;
  dimensions.reserve(rank);
  for (decltype(dimensions.size()) i = 0; i < rank; ++i) {
    dimensions.push_back(i);
  }
  return Reshape(operand, dimensions, new_sizes);
}
408 
// Collapses the (in-order, consecutive) dimensions listed in
// `dims_to_collapse` of `operand` into a single dimension whose size is the
// product of theirs, implemented as a reshape.  Out-of-order or
// non-consecutive requests are rejected with an error.
ComputationDataHandle ComputationBuilder::Collapse(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dims_to_collapse) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  // Don't support out-of-order collapse here.
  // Checks that the collapsed dimensions are in order and consecutive.
  for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
       i < dims_to_collapse.size(); ++i) {
    if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) {
      NoteError(InvalidArgument(
          "Collapsed dimensions are not in order and consecutive."));
      return ComputationDataHandle();
    }
  }

  // Create a new sizes vector from the old shape, replacing the collapsed
  // dimensions by the product of their sizes.
  StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand);
  if (!shape_or_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();

  VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
  VLOG(3) << "dims to collapse: "
          << tensorflow::str_util::Join(dims_to_collapse, ",");

  if (dims_to_collapse.size() <= 1) {
    // Not collapsing anything, trivially we can return the operand versus
    // enqueueing a trivial reshape.
    return operand;
  }

  std::vector<int64> new_sizes;
  for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) {
    // Dimensions up to and including the first collapsed one, and dimensions
    // after the last collapsed one, are carried over unchanged; each
    // remaining collapsed dimension is folded (multiplied) into the entry
    // created for the first collapsed dimension.
    if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) {
      new_sizes.push_back(original_shape->dimensions(i));
    } else {
      new_sizes.back() *= original_shape->dimensions(i);
    }
  }

  VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
          << "]";

  return Reshape(operand, new_sizes);
}
459 
Trace(const string & tag,const ComputationDataHandle & operand)460 void ComputationBuilder::Trace(const string& tag,
461                                const ComputationDataHandle& operand) {
462   OpRequest op_request;
463   TraceRequest* request = op_request.mutable_trace_request();
464   request->set_tag(tag);
465   *request->mutable_operand() = operand;
466   RunOpAndNoteError(&op_request);
467 }
468 
// Element-wise select: where `pred` is true take `on_true`, else `on_false`.
ComputationDataHandle ComputationBuilder::Select(
    const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
    const ComputationDataHandle& on_false) {
  return TernaryOp(TRIOP_SELECT, pred, on_true, on_false);
}
474 
Tuple(tensorflow::gtl::ArraySlice<ComputationDataHandle> elements)475 ComputationDataHandle ComputationBuilder::Tuple(
476     tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
477   OpRequest op_request;
478   VariadicOpRequest* request = op_request.mutable_variadic_op_request();
479   request->set_varop(VAROP_TUPLE);
480   for (const ComputationDataHandle& operand : elements) {
481     *request->add_operands() = operand;
482   }
483   return RunOpAndParseResponse(&op_request);
484 }
485 
// Enqueues an op extracting element `index` from the tuple `tuple_data`.
ComputationDataHandle ComputationBuilder::GetTupleElement(
    const ComputationDataHandle& tuple_data, int64 index) {
  OpRequest op_request;
  GetTupleElementRequest* gte = op_request.mutable_get_tuple_element_request();
  gte->set_index(index);
  *gte->mutable_operand() = tuple_data;
  return RunOpAndParseResponse(&op_request);
}
495 
// Element-wise equality comparison, with optional broadcasting of the
// lower-rank operand over `broadcast_dimensions`.
ComputationDataHandle ComputationBuilder::Eq(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions);
}
501 
// Element-wise inequality comparison; see Eq for broadcasting semantics.
ComputationDataHandle ComputationBuilder::Ne(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions);
}
507 
// Element-wise greater-or-equal comparison; see Eq for broadcasting.
ComputationDataHandle ComputationBuilder::Ge(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions);
}
513 
// Element-wise greater-than comparison; see Eq for broadcasting.
ComputationDataHandle ComputationBuilder::Gt(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions);
}
519 
// Element-wise less-or-equal comparison; see Eq for broadcasting.
ComputationDataHandle ComputationBuilder::Le(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions);
}
525 
// Element-wise less-than comparison; see Eq for broadcasting.
ComputationDataHandle ComputationBuilder::Lt(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions);
}
531 
// Enqueues a standard dot product: contracts the last dimension of `lhs`
// (dimension 0 for a rank-1 lhs, otherwise dimension 1) against dimension 0
// of `rhs`, by delegating to DotGeneral.
ComputationDataHandle ComputationBuilder::Dot(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    // Error already noted by GetShape.
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();

  DotDimensionNumbers dimension_numbers;
  // Vector lhs contracts its only dimension; matrix lhs contracts its second.
  dimension_numbers.add_lhs_contracting_dimensions(
      lhs_shape->dimensions_size() == 1 ? 0 : 1);
  dimension_numbers.add_rhs_contracting_dimensions(0);
  return DotGeneral(lhs, rhs, dimension_numbers);
}
546 
// Enqueues a general dot op with explicitly specified contracting/batch
// dimension numbers.
ComputationDataHandle ComputationBuilder::DotGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  OpRequest op_request;
  DotRequest* dot = op_request.mutable_dot_request();
  *dot->mutable_dimension_numbers() = dimension_numbers;
  *dot->mutable_lhs() = lhs;
  *dot->mutable_rhs() = rhs;
  return RunOpAndParseResponse(&op_request);
}
557 
// Enqueues a convolution using the default dimension-number layout derived
// from the number of window strides.
ComputationDataHandle ComputationBuilder::Conv(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()));
}
565 
// Enqueues a convolution with explicit per-dimension (low, high) padding
// and the default dimension-number layout.
ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()));
}
573 
VerifyConvolution(const Shape & lhs_shape,const Shape & rhs_shape,const ConvolutionDimensionNumbers & dimension_numbers)574 bool ComputationBuilder::VerifyConvolution(
575     const Shape& lhs_shape, const Shape& rhs_shape,
576     const ConvolutionDimensionNumbers& dimension_numbers) {
577   if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
578     NoteError(
579         InvalidArgument("Convolution arguments must have same number of "
580                         "dimensions. Got: %s and %s",
581                         ShapeUtil::HumanString(lhs_shape).c_str(),
582                         ShapeUtil::HumanString(rhs_shape).c_str()));
583     return false;
584   }
585   int num_dims = ShapeUtil::Rank(lhs_shape);
586   if (num_dims < 2) {
587     NoteError(InvalidArgument(
588         "Convolution expects argument arrays with >= 3 dimensions. "
589         "Got: %s and %s",
590         ShapeUtil::HumanString(lhs_shape).c_str(),
591         ShapeUtil::HumanString(rhs_shape).c_str()));
592     return false;
593   }
594   int num_spatial_dims = num_dims - 2;
595 
596   const auto check_spatial_dimensions =
597       [&](const char* const field_name,
598           const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
599               numbers) {
600         if (numbers.size() != num_spatial_dims) {
601           NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
602                                     num_spatial_dims, field_name,
603                                     numbers.size()));
604           return false;
605         }
606         for (int i = 0; i < numbers.size(); ++i) {
607           if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
608             NoteError(
609                 InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
610                                 field_name, i, numbers.Get(i)));
611             return false;
612           }
613         }
614         return true;
615       };
616   return check_spatial_dimensions(
617              "input_spatial_dimensions",
618              dimension_numbers.input_spatial_dimensions()) &&
619          check_spatial_dimensions(
620              "kernel_spatial_dimensions",
621              dimension_numbers.kernel_spatial_dimensions()) &&
622          check_spatial_dimensions(
623              "output_spatial_dimensions",
624              dimension_numbers.output_spatial_dimensions());
625 }
626 
// Enqueues a convolution with explicit dimension numbers, deriving the
// per-spatial-dimension padding from the kSame/kValid `padding` mode.
ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
  if (!rhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();

  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    // Error is recorded in VerifyConvolution; noting an additional generic
    // InternalError here (as this function previously did, unlike
    // ConvGeneralDilated) was a dead no-op since the first error is already
    // latched, and would only obscure the specific cause.
    return ComputationDataHandle();
  }

  // Input spatial extents, in spatial-dimension order...
  std::vector<int64> base_area_dimensions(
      dimension_numbers.input_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
       ++i) {
    base_area_dimensions[i] =
        lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
  }

  // ...and kernel spatial extents, from which MakePadding computes the
  // explicit low/high padding for each spatial dimension.
  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  return ConvGeneral(lhs, rhs, window_strides,
                     MakePadding(base_area_dimensions, window_dimensions,
                                 window_strides, padding),
                     dimension_numbers);
}
673 
// Enqueues a convolution with explicit padding and dimension numbers, and
// no dilation (empty dilation slices default to 1 in MakeWindow).
ComputationDataHandle ComputationBuilder::ConvGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers);
}
682 
// Enqueues the fully general convolution: explicit padding plus lhs/rhs
// dilation factors. Returns an invalid handle (and records the error) on any
// failure along the way.
ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  // Fetch both operand shapes; GetShape records any failure.
  StatusOr<std::unique_ptr<Shape>> lhs_shape_status = GetShape(lhs);
  if (!lhs_shape_status.ok()) {
    return ComputationDataHandle();
  }
  StatusOr<std::unique_ptr<Shape>> rhs_shape_status = GetShape(rhs);
  if (!rhs_shape_status.ok()) {
    return ComputationDataHandle();
  }

  std::unique_ptr<Shape> lhs_shape = lhs_shape_status.ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = rhs_shape_status.ConsumeValueOrDie();
  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    // Error is recorded in VerifyConvolution.
    return ComputationDataHandle();
  }

  // The window dimensions are the kernel's spatial dimensions, ordered per
  // the dimension numbers.
  const int num_spatial = dimension_numbers.kernel_spatial_dimensions_size();
  std::vector<int64> window_dimensions(num_spatial);
  for (int i = 0; i < num_spatial; ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  OpRequest op_request;
  ConvolveRequest* request = op_request.mutable_convolve_request();
  *request->mutable_dimension_numbers() = dimension_numbers;
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;

  if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation,
                  rhs_dilation, request->mutable_window())) {
    // Error is recorded in MakeWindow.
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}
732 
// Enqueues an FFT instruction of kind `fft_type` over the trailing
// `fft_length` dimensions of `operand`.
ComputationDataHandle ComputationBuilder::Fft(
    const ComputationDataHandle& operand, const FftType fft_type,
    const tensorflow::gtl::ArraySlice<int64> fft_length) {
  OpRequest op_request;
  FftRequest* fft_request = op_request.mutable_fft_request();
  fft_request->set_fft_type(fft_type);
  *fft_request->mutable_operand() = operand;
  for (int64 length : fft_length) {
    fft_request->add_fft_length(length);
  }
  return RunOpAndParseResponse(&op_request);
}
745 
Infeed(const Shape & shape,const string & config)746 ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
747                                                  const string& config) {
748   OpRequest op_request;
749   InfeedRequest* request = op_request.mutable_infeed_request();
750   *request->mutable_shape() = shape;
751   *request->mutable_config() = config;
752   return RunOpAndParseResponse(&op_request);
753 }
754 
Outfeed(const ComputationDataHandle & operand,const Shape & shape,const string & outfeed_config)755 void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
756                                  const Shape& shape,
757                                  const string& outfeed_config) {
758   OpRequest op_request;
759   OutfeedRequest* request = op_request.mutable_outfeed_request();
760   request->set_outfeed_config(outfeed_config);
761   *request->mutable_operand() = operand;
762   *request->mutable_shape() = shape;
763   RunOpAndNoteError(&op_request);
764 }
765 
Call(const Computation & computation,tensorflow::gtl::ArraySlice<ComputationDataHandle> operands)766 ComputationDataHandle ComputationBuilder::Call(
767     const Computation& computation,
768     tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
769   OpRequest op_request;
770   CallRequest* request = op_request.mutable_call_request();
771   *request->mutable_to_apply() = computation.handle();
772   for (const ComputationDataHandle& operand : operands) {
773     *request->add_operands() = operand;
774   }
775   return RunOpAndParseResponse(&op_request);
776 }
777 
CustomCall(const string & call_target_name,tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,const Shape & shape)778 ComputationDataHandle ComputationBuilder::CustomCall(
779     const string& call_target_name,
780     tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
781     const Shape& shape) {
782   OpRequest op_request;
783   CustomCallRequest* request = op_request.mutable_custom_call_request();
784   request->set_call_target_name(call_target_name);
785   for (const ComputationDataHandle& operand : operands) {
786     *request->add_operands() = operand;
787   }
788   *request->mutable_shape() = shape;
789   return RunOpAndParseResponse(&op_request);
790 }
791 
HostCompute(tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,const string & channel_name,int64 cost_estimate_ns,const Shape & shape)792 ComputationDataHandle ComputationBuilder::HostCompute(
793     tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
794     const string& channel_name, int64 cost_estimate_ns, const Shape& shape) {
795   OpRequest op_request;
796   HostComputeRequest* request = op_request.mutable_host_compute_request();
797   for (const ComputationDataHandle& operand : operands) {
798     *request->add_operands() = operand;
799   }
800   *request->mutable_shape() = shape;
801   request->set_channel_name(channel_name);
802   request->set_cost_estimate_ns(cost_estimate_ns);
803   return RunOpAndParseResponse(&op_request);
804 }
805 
// Enqueues an element-wise complex-number constructor from real and
// imaginary parts; broadcast_dimensions is forwarded to BinaryOp.
ComputationDataHandle ComputationBuilder::Complex(
    const ComputationDataHandle& real, const ComputationDataHandle& imag,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
}
811 
// Enqueues a complex conjugate, built as Complex(Real(x), -Imag(x)).
ComputationDataHandle ComputationBuilder::Conj(
    const ComputationDataHandle& operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}
816 
// Enqueues an element-wise add; broadcast_dimensions is forwarded to
// BinaryOp.
ComputationDataHandle ComputationBuilder::Add(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions);
}
822 
// Enqueues an element-wise subtract; broadcast_dimensions is forwarded to
// BinaryOp.
ComputationDataHandle ComputationBuilder::Sub(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions);
}
828 
// Enqueues an element-wise multiply; broadcast_dimensions is forwarded to
// BinaryOp.
ComputationDataHandle ComputationBuilder::Mul(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions);
}
834 
// Enqueues an element-wise divide; broadcast_dimensions is forwarded to
// BinaryOp.
ComputationDataHandle ComputationBuilder::Div(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions);
}
840 
// Enqueues an element-wise remainder; broadcast_dimensions is forwarded to
// BinaryOp.
ComputationDataHandle ComputationBuilder::Rem(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions);
}
846 
// Enqueues an element-wise maximum; broadcast_dimensions is forwarded to
// BinaryOp.
ComputationDataHandle ComputationBuilder::Max(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions);
}
852 
// Enqueues an element-wise minimum; broadcast_dimensions is forwarded to
// BinaryOp.
ComputationDataHandle ComputationBuilder::Min(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions);
}
858 
// Enqueues an element-wise logical/bitwise AND; broadcast_dimensions is
// forwarded to BinaryOp.
ComputationDataHandle ComputationBuilder::And(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions);
}
864 
// Enqueues an element-wise logical/bitwise OR; broadcast_dimensions is
// forwarded to BinaryOp.
ComputationDataHandle ComputationBuilder::Or(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions);
}
870 
// Enqueues an element-wise logical/bitwise NOT.
ComputationDataHandle ComputationBuilder::Not(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NOT, operand);
}
875 
// Enqueues an element-wise left shift of lhs by rhs; broadcast_dimensions
// is forwarded to BinaryOp.
ComputationDataHandle ComputationBuilder::ShiftLeft(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions);
}
881 
// Enqueues an element-wise arithmetic (sign-preserving) right shift;
// broadcast_dimensions is forwarded to BinaryOp.
ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions);
}
887 
// Enqueues an element-wise logical (zero-filling) right shift;
// broadcast_dimensions is forwarded to BinaryOp.
ComputationDataHandle ComputationBuilder::ShiftRightLogical(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions);
}
893 
// Enqueues an element-wise absolute value.
ComputationDataHandle ComputationBuilder::Abs(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ABS, operand);
}
898 
// Enqueues an element-wise two-argument arctangent atan2(y, x);
// broadcast_dimensions is forwarded to BinaryOp.
ComputationDataHandle ComputationBuilder::Atan2(
    const ComputationDataHandle& y, const ComputationDataHandle& x,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
}
904 
// Enqueues an element-wise exponential.
ComputationDataHandle ComputationBuilder::Exp(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_EXP, operand);
}
909 
// Enqueues an element-wise floor.
ComputationDataHandle ComputationBuilder::Floor(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_FLOOR, operand);
}
914 
// Enqueues an element-wise ceiling.
ComputationDataHandle ComputationBuilder::Ceil(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_CEIL, operand);
}
919 
// Enqueues an element-wise round-to-nearest (AFZ variant, per the op name:
// ties rounded away from zero).
ComputationDataHandle ComputationBuilder::Round(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand);
}
924 
// Enqueues an element-wise natural logarithm.
ComputationDataHandle ComputationBuilder::Log(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_LOG, operand);
}
929 
// Enqueues an element-wise sign operation.
ComputationDataHandle ComputationBuilder::Sign(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIGN, operand);
}
934 
// Enqueues an element-wise cosine.
ComputationDataHandle ComputationBuilder::Cos(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_COS, operand);
}
939 
// Enqueues an element-wise sine.
ComputationDataHandle ComputationBuilder::Sin(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIN, operand);
}
944 
// Enqueues an element-wise hyperbolic tangent.
ComputationDataHandle ComputationBuilder::Tanh(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_TANH, operand);
}
949 
// Enqueues an element-wise extraction of the real part.
ComputationDataHandle ComputationBuilder::Real(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_REAL, operand);
}
954 
// Enqueues an element-wise extraction of the imaginary part.
ComputationDataHandle ComputationBuilder::Imag(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IMAG, operand);
}
959 
// Enqueues an element-wise finiteness test.
ComputationDataHandle ComputationBuilder::IsFinite(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IS_FINITE, operand);
}
964 
// Enqueues a transpose of `operand` with the given dimension permutation.
ComputationDataHandle ComputationBuilder::Transpose(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> permutation) {
  OpRequest op_request;
  TransposeRequest* transpose_request = op_request.mutable_transpose_request();
  *transpose_request->mutable_operand() = operand;
  for (int64 dim : permutation) {
    transpose_request->add_dimensions(dim);
  }
  return RunOpAndParseResponse(&op_request);
}
976 
// Enqueues a reverse of `operand` along the listed dimensions.
ComputationDataHandle ComputationBuilder::Rev(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  OpRequest op_request;
  ReverseRequest* reverse_request = op_request.mutable_reverse_request();
  *reverse_request->mutable_operand() = operand;
  for (int64 dim : dimensions) {
    reverse_request->add_dimensions(dim);
  }
  return RunOpAndParseResponse(&op_request);
}
988 
// Enqueues a sort of `operand` as a unary op.
ComputationDataHandle ComputationBuilder::Sort(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SORT, operand);
}
993 
// Enqueues an element-wise square root, implemented as pow(operand, 0.5)
// with an F32 scalar exponent.
ComputationDataHandle ComputationBuilder::SqrtF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(0.5),
                  /*broadcast_dimensions=*/{});
}
999 
// Enqueues an element-wise power lhs^rhs; broadcast_dimensions is forwarded
// to BinaryOp.
ComputationDataHandle ComputationBuilder::Pow(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions);
}
1005 
// Enqueues a converting cast of `operand`'s elements to `new_element_type`.
// Returns an invalid handle if the builder is already in an error state or
// the operand's shape cannot be fetched.
ComputationDataHandle ComputationBuilder::ConvertElementType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  // Validate the operand by fetching its shape; the shape value itself is
  // not needed here, so the previous unused local was dropped.
  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}
1024 
// Enqueues a bitcast-converting cast of `operand`'s elements to
// `new_element_type` (reinterprets the bits rather than converting values).
// Returns an invalid handle if the builder is already in an error state or
// the operand's shape cannot be fetched.
ComputationDataHandle ComputationBuilder::BitcastConvertType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  // Validate the operand by fetching its shape; the shape value itself is
  // not needed here, so the previous unused local was dropped.
  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_bitcast_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}
1043 
// Enqueues an element-wise square, implemented as pow(operand, 2.0) with an
// F32 scalar exponent.
ComputationDataHandle ComputationBuilder::SquareF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),
                  /*broadcast_dimensions=*/{});
}
1049 
// Enqueues an element-wise reciprocal, implemented as pow(operand, -1.0)
// with an F32 scalar exponent.
ComputationDataHandle ComputationBuilder::ReciprocalF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(-1.0),
                  /*broadcast_dimensions=*/{});
}
1055 
// Enqueues an element-wise negation.
ComputationDataHandle ComputationBuilder::Neg(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NEGATE, operand);
}
1060 
// Enqueues an element-wise clamp of `operand` to the range [min, max] as a
// ternary op.
ComputationDataHandle ComputationBuilder::Clamp(
    const ComputationDataHandle& min, const ComputationDataHandle& operand,
    const ComputationDataHandle& max) {
  return TernaryOp(TRIOP_CLAMP, min, operand, max);
}
1066 
UnaryOp(UnaryOperation unop,const ComputationDataHandle & operand)1067 ComputationDataHandle ComputationBuilder::UnaryOp(
1068     UnaryOperation unop, const ComputationDataHandle& operand) {
1069   OpRequest op_request;
1070   UnaryOpRequest* request = op_request.mutable_unary_op_request();
1071   request->set_unop(unop);
1072   *request->mutable_operand() = operand;
1073   return RunOpAndParseResponse(&op_request);
1074 }
1075 
BinaryOp(BinaryOperation binop,const ComputationDataHandle & lhs,const ComputationDataHandle & rhs,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)1076 ComputationDataHandle ComputationBuilder::BinaryOp(
1077     BinaryOperation binop, const ComputationDataHandle& lhs,
1078     const ComputationDataHandle& rhs,
1079     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
1080   OpRequest op_request;
1081   BinaryOpRequest* request = op_request.mutable_binary_op_request();
1082   request->set_binop(binop);
1083   *request->mutable_lhs() = lhs;
1084   *request->mutable_rhs() = rhs;
1085   for (int64 dimension : broadcast_dimensions) {
1086     request->add_broadcast_dimensions(dimension);
1087   }
1088   return RunOpAndParseResponse(&op_request);
1089 }
1090 
RngOp(RandomDistribution distribution,tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,const Shape & shape)1091 ComputationDataHandle ComputationBuilder::RngOp(
1092     RandomDistribution distribution,
1093     tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
1094     const Shape& shape) {
1095   OpRequest op_request;
1096   RngRequest* request = op_request.mutable_rng_request();
1097   request->set_distribution(distribution);
1098   for (const ComputationDataHandle& param : parameters) {
1099     *request->add_parameter() = param;
1100   }
1101   *request->mutable_shape() = shape;
1102   return RunOpAndParseResponse(&op_request);
1103 }
1104 
TernaryOp(TernaryOperation triop,const ComputationDataHandle & lhs,const ComputationDataHandle & rhs,const ComputationDataHandle & ehs)1105 ComputationDataHandle ComputationBuilder::TernaryOp(
1106     TernaryOperation triop, const ComputationDataHandle& lhs,
1107     const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
1108   OpRequest op_request;
1109   TernaryOpRequest* request = op_request.mutable_ternary_op_request();
1110   request->set_triop(triop);
1111   *request->mutable_lhs() = lhs;
1112   *request->mutable_rhs() = rhs;
1113   *request->mutable_ehs() = ehs;
1114   return RunOpAndParseResponse(&op_request);
1115 }
1116 
SetReturnValue(const ComputationDataHandle & operand)1117 Status ComputationBuilder::SetReturnValue(
1118     const ComputationDataHandle& operand) {
1119   TF_RETURN_IF_ERROR(first_error_);
1120 
1121   SetReturnValueRequest request;
1122   *request.mutable_computation() = computation_.handle();
1123   *request.mutable_operand() = operand;
1124 
1125   SetReturnValueResponse response;
1126 
1127   VLOG(2) << "making set-handle-to-execute request";
1128   Status s = client_->stub()->SetReturnValue(&request, &response);
1129   VLOG(2) << "done with request";
1130 
1131   if (!s.ok()) {
1132     NoteError(s);
1133     return first_error_;
1134   }
1135 
1136   return Status::OK();
1137 }
1138 
IsConstant(const ComputationDataHandle & operand,int64 num_parameters)1139 StatusOr<bool> ComputationBuilder::IsConstant(
1140     const ComputationDataHandle& operand, int64 num_parameters) {
1141   TF_RETURN_IF_ERROR(first_error_);
1142 
1143   IsConstantRequest request;
1144   *request.mutable_computation() = computation_.handle();
1145   *request.mutable_operand() = operand;
1146   request.set_num_parameters(num_parameters);
1147   IsConstantResponse response;
1148 
1149   VLOG(2) << "making IsConstant request";
1150   Status s = client_->stub()->IsConstant(&request, &response);
1151   VLOG(2) << "done with request";
1152 
1153   if (!s.ok()) {
1154     return s;
1155   }
1156   return response.is_constant();
1157 }
1158 
// Asks the service to evaluate `operand` as a constant, optionally with an
// explicit output layout and literal values for parameters. Returns the
// evaluated literal, the RPC error, or an InternalError if the response
// carries no literal.
StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
    const ComputationDataHandle& operand, const Layout* output_layout,
    tensorflow::gtl::ArraySlice<Literal> parameters) {
  TF_RETURN_IF_ERROR(first_error_);

  ComputeConstantRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  // The output layout is optional; omit the field entirely when unset.
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = *output_layout;
  }
  for (const auto& param : parameters) {
    *request.add_parameters() = param.ToProto();
  }

  ComputeConstantResponse response;

  VLOG(2) << "making compute-constant request";
  Status s = client_->stub()->ComputeConstant(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

  // A successful RPC must still carry a literal payload.
  if (!response.has_literal()) {
    return InternalError(
        "no computed literal in the provided response in ComputeConstant "
        "request");
  }
  return Literal::CreateFromProto(response.literal());
}
1193 
Map(tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,const Computation & computation,tensorflow::gtl::ArraySlice<int64> dimensions,tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands)1194 ComputationDataHandle ComputationBuilder::Map(
1195     tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
1196     const Computation& computation,
1197     tensorflow::gtl::ArraySlice<int64> dimensions,
1198     tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
1199   OpRequest op_request;
1200   MapRequest* request = op_request.mutable_map_request();
1201   for (const ComputationDataHandle& operand : operands) {
1202     *request->add_operands() = operand;
1203   }
1204   *request->mutable_to_apply() = computation.handle();
1205   for (int64 dimension : dimensions) {
1206     request->add_dimensions(dimension);
1207   }
1208   for (const ComputationDataHandle& sop : static_operands) {
1209     *request->add_static_operands() = sop;
1210   }
1211   return RunOpAndParseResponse(&op_request);
1212 }
1213 
// Enqueues a normal-distribution RNG with parameters {mu, sigma}, producing
// a value of `shape`.
ComputationDataHandle ComputationBuilder::RngNormal(
    const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}
1219 
// Enqueues a uniform-distribution RNG with parameters {a, b}, producing a
// value of `shape`.
ComputationDataHandle ComputationBuilder::RngUniform(
    const ComputationDataHandle& a, const ComputationDataHandle& b,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}
1225 
// Enqueues a while loop: repeatedly applies `body` starting from `init`
// while `condition` holds.
ComputationDataHandle ComputationBuilder::While(
    const Computation& condition, const Computation& body,
    const ComputationDataHandle& init) {
  OpRequest op_request;
  WhileRequest* while_request = op_request.mutable_while_request();
  *while_request->mutable_init() = init;
  *while_request->mutable_condition() = condition.handle();
  *while_request->mutable_body() = body.handle();
  return RunOpAndParseResponse(&op_request);
}
1236 
// Enqueues a gather of windows from `input` at positions `gather_indices`,
// with per-dimension `window_bounds`.
ComputationDataHandle ComputationBuilder::Gather(
    const ComputationDataHandle& input,
    const ComputationDataHandle& gather_indices,
    const GatherDimensionNumbers& dimension_numbers,
    tensorflow::gtl::ArraySlice<int64> window_bounds) {
  OpRequest op_request;
  GatherRequest* request = op_request.mutable_gather_request();
  *request->mutable_dimension_numbers() = dimension_numbers;
  *request->mutable_input() = input;
  *request->mutable_gather_indices() = gather_indices;
  for (int64 bound : window_bounds) {
    request->add_window_bounds(bound);
  }
  return RunOpAndParseResponse(&op_request);
}
1252 
// Enqueues a conditional: applies `true_computation` to `true_operand` when
// `predicate` holds, otherwise `false_computation` to `false_operand`.
ComputationDataHandle ComputationBuilder::Conditional(
    const ComputationDataHandle& predicate,
    const ComputationDataHandle& true_operand,
    const Computation& true_computation,
    const ComputationDataHandle& false_operand,
    const Computation& false_computation) {
  OpRequest op_request;
  ConditionalRequest* cond_request = op_request.mutable_conditional_request();
  *cond_request->mutable_predicate() = predicate;
  *cond_request->mutable_true_computation() = true_computation.handle();
  *cond_request->mutable_false_computation() = false_computation.handle();
  *cond_request->mutable_true_operand() = true_operand;
  *cond_request->mutable_false_operand() = false_operand;
  return RunOpAndParseResponse(&op_request);
}
1268 
// Enqueues a reduction of `operand` over `dimensions_to_reduce` using
// `computation`, seeded with `init_value`.
ComputationDataHandle ComputationBuilder::Reduce(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  OpRequest op_request;
  ReduceRequest* reduce_request = op_request.mutable_reduce_request();
  *reduce_request->mutable_to_apply() = computation.handle();
  *reduce_request->mutable_operand() = operand;
  *reduce_request->mutable_init_value() = init_value;
  for (int64 dim : dimensions_to_reduce) {
    reduce_request->add_dimensions(dim);
  }
  return RunOpAndParseResponse(&op_request);
}
1283 
// Reduces `operand` across all of its dimensions using `computation`,
// seeded with `init_value`.
ComputationDataHandle ComputationBuilder::ReduceAll(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }

  // Reduce over every dimension: [0, rank).
  std::vector<int64> all_dimensions(
      ShapeUtil::Rank(*shape_status.ValueOrDie()));
  std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
  return Reduce(operand, init_value, computation, all_dimensions);
}
1300 
// Enqueues a windowed reduction of `operand` with symbolic `padding`
// (kSame/kValid), by computing explicit padding values and delegating to
// ReduceWindowWithGeneralPadding. Returns an invalid handle on error.
ComputationDataHandle ComputationBuilder::ReduceWindow(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }

  Status padding_valid =
      ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                            window_dimensions, window_strides);
  if (!padding_valid.ok()) {
    // Route through NoteError (rather than assigning first_error_ directly)
    // so die_immediately_on_error_ is honored, consistent with the other
    // error paths in this builder.
    NoteError(padding_valid);
    return ComputationDataHandle();
  }

  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding);
  return ReduceWindowWithGeneralPadding(operand, init_value, computation,
                                        window_dimensions, window_strides,
                                        padding_values);
}
1330 
// Windowed reduction with caller-supplied explicit (low, high) padding per
// dimension. Returns an invalid handle if the window cannot be constructed.
ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  OpRequest op_request;
  ReduceWindowRequest* req = op_request.mutable_reduce_window_request();

  // Build the window description first; abort before filling in the rest of
  // the request if it is malformed.
  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  req->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  *req->mutable_operand() = operand;
  *req->mutable_init_value() = init_value;
  *req->mutable_to_apply() = computation.handle();
  return RunOpAndParseResponse(&op_request);
}
1351 
// Emits a batch-norm training op over `operand` using per-feature `scale`
// and `offset`, normalizing along `feature_index` with the given `epsilon`.
ComputationDataHandle ComputationBuilder::BatchNormTraining(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
  OpRequest op_request;
  auto* req = op_request.mutable_batch_norm_training_request();
  req->set_epsilon(epsilon);
  req->set_feature_index(feature_index);
  *req->mutable_operand() = operand;
  *req->mutable_scale() = scale;
  *req->mutable_offset() = offset;
  return RunOpAndParseResponse(&op_request);
}
1365 
// Emits a batch-norm inference op: normalizes `operand` with the provided
// precomputed `mean` and `variance` (no statistics are computed here).
ComputationDataHandle ComputationBuilder::BatchNormInference(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, const ComputationDataHandle& mean,
    const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
  OpRequest op_request;
  auto* req = op_request.mutable_batch_norm_inference_request();
  req->set_epsilon(epsilon);
  req->set_feature_index(feature_index);
  *req->mutable_operand() = operand;
  *req->mutable_scale() = scale;
  *req->mutable_offset() = offset;
  *req->mutable_mean() = mean;
  *req->mutable_variance() = variance;
  return RunOpAndParseResponse(&op_request);
}
1382 
// Emits the gradient op for batch normalization, producing gradients with
// respect to the operand, scale, and offset from `grad_output`.
ComputationDataHandle ComputationBuilder::BatchNormGrad(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& mean, const ComputationDataHandle& var,
    const ComputationDataHandle& grad_output, float epsilon,
    int64 feature_index) {
  OpRequest op_request;
  auto* req = op_request.mutable_batch_norm_grad_request();
  req->set_epsilon(epsilon);
  req->set_feature_index(feature_index);
  *req->mutable_operand() = operand;
  *req->mutable_scale() = scale;
  *req->mutable_mean() = mean;
  *req->mutable_variance() = var;
  *req->mutable_grad_output() = grad_output;
  return RunOpAndParseResponse(&op_request);
}
1399 
// Emits a cross-replica sum op that reduces `operand` across replicas.
ComputationDataHandle ComputationBuilder::CrossReplicaSum(
    const ComputationDataHandle& operand) {
  OpRequest op_request;
  auto* req = op_request.mutable_cross_replica_sum_request();
  *req->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}
1408 
// Select-and-scatter over `operand` with a symbolic `padding` mode: the mode
// is expanded into explicit padding values from the operand's shape, then the
// call is forwarded to SelectAndScatterWithGeneralPadding.
ComputationDataHandle ComputationBuilder::SelectAndScatter(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }

  const std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding);
  return SelectAndScatterWithGeneralPadding(operand, select, window_dimensions,
                                            window_strides, padding_values,
                                            source, init_value, scatter);
}
1429 
// Select-and-scatter with caller-supplied explicit padding: `select` picks an
// element within each window of `operand`, and `scatter` combines the
// corresponding `source` value into the result (initialized to `init_value`).
ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  OpRequest op_request;
  auto* req = op_request.mutable_select_and_scatter_request();

  // Construct the window first; abandon the request if it is malformed.
  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  req->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  *req->mutable_operand() = operand;
  *req->mutable_source() = source;
  *req->mutable_init_value() = init_value;
  *req->mutable_select() = select.handle();
  *req->mutable_scatter() = scatter.handle();
  return RunOpAndParseResponse(&op_request);
}
1454 
// Emits an op that rounds `operand` to a lower-precision float format with
// the given number of exponent and mantissa bits.
ComputationDataHandle ComputationBuilder::ReducePrecision(
    const ComputationDataHandle& operand, const int exponent_bits,
    const int mantissa_bits) {
  OpRequest op_request;
  auto* req = op_request.mutable_reduce_precision_request();
  req->set_exponent_bits(exponent_bits);
  req->set_mantissa_bits(mantissa_bits);
  *req->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}
1466 
Send(const ComputationDataHandle & operand,const ChannelHandle & handle)1467 void ComputationBuilder::Send(const ComputationDataHandle& operand,
1468                               const ChannelHandle& handle) {
1469   OpRequest op_request;
1470   SendRequest* request = op_request.mutable_send_request();
1471   *request->mutable_operand() = operand;
1472   *request->mutable_channel_handle() = handle;
1473   *op_request.mutable_computation() = computation_.handle();
1474   RunOpAndNoteError(&op_request);
1475 }
1476 
Recv(const Shape & shape,const ChannelHandle & handle)1477 ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
1478                                                const ChannelHandle& handle) {
1479   OpRequest op_request;
1480   RecvRequest* request = op_request.mutable_recv_request();
1481   *request->mutable_shape() = shape;
1482   *request->mutable_channel_handle() = handle;
1483   return RunOpAndParseResponse(&op_request);
1484 }
1485 
BuildAndNoteError()1486 Computation ComputationBuilder::BuildAndNoteError() {
1487   DCHECK(parent_builder_ != nullptr);
1488   auto build_status = Build();
1489   if (!build_status.ok()) {
1490     parent_builder_->NoteError(
1491         AddStatus(build_status.status(),
1492                   tensorflow::strings::StrCat("error from: ", name_)));
1493     return Computation();
1494   }
1495   return build_status.ConsumeValueOrDie();
1496 }
1497 
Build()1498 StatusOr<Computation> ComputationBuilder::Build() {
1499   if (!first_error_.ok()) {
1500     string backtrace;
1501     first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
1502     return AppendStatus(first_error_, backtrace);
1503   }
1504 
1505   if (computation_.IsNull()) {
1506     return FailedPrecondition("no computation was built");
1507   }
1508 
1509   return {std::move(computation_)};
1510 }
1511 
1512 /* static */ ConvolutionDimensionNumbers
CreateDefaultConvDimensionNumbers(int num_spatial_dims)1513 ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
1514   ConvolutionDimensionNumbers dimension_numbers;
1515   dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
1516   dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
1517   dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
1518   dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
1519   dimension_numbers.set_kernel_output_feature_dimension(
1520       kConvKernelOutputDimension);
1521   dimension_numbers.set_kernel_input_feature_dimension(
1522       kConvKernelInputDimension);
1523   for (int i = 0; i < num_spatial_dims; ++i) {
1524     dimension_numbers.add_input_spatial_dimensions(i + 2);
1525     dimension_numbers.add_kernel_spatial_dimensions(i + 2);
1526     dimension_numbers.add_output_spatial_dimensions(i + 2);
1527   }
1528   return dimension_numbers;
1529 }
1530 
1531 /* static */ StatusOr<ConvolutionDimensionNumbers>
CreateConvDimensionNumbers(int64 input_batch,int64 input_feature,int64 input_first_spatial,int64 input_second_spatial,int64 output_batch,int64 output_feature,int64 output_first_spatial,int64 output_second_spatial,int64 kernel_output_feature,int64 kernel_input_feature,int64 kernel_first_spatial,int64 kernel_second_spatial)1532 ComputationBuilder::CreateConvDimensionNumbers(
1533     int64 input_batch, int64 input_feature, int64 input_first_spatial,
1534     int64 input_second_spatial, int64 output_batch, int64 output_feature,
1535     int64 output_first_spatial, int64 output_second_spatial,
1536     int64 kernel_output_feature, int64 kernel_input_feature,
1537     int64 kernel_first_spatial, int64 kernel_second_spatial) {
1538   if (std::set<int64>({input_batch, input_feature, input_first_spatial,
1539                        input_second_spatial})
1540           .size() != 4) {
1541     return FailedPrecondition(
1542         "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
1543         "%lld)",
1544         input_batch, input_feature, input_first_spatial, input_second_spatial);
1545   }
1546   if (std::set<int64>({kernel_output_feature, kernel_input_feature,
1547                        kernel_first_spatial, kernel_second_spatial})
1548           .size() != 4) {
1549     return FailedPrecondition(
1550         "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
1551         "%lld)",
1552         kernel_output_feature, kernel_input_feature, kernel_first_spatial,
1553         kernel_second_spatial);
1554   }
1555   if (std::set<int64>({output_batch, output_feature, output_first_spatial,
1556                        output_second_spatial})
1557           .size() != 4) {
1558     return FailedPrecondition(
1559         "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
1560         "%lld)",
1561         output_batch, output_feature, output_first_spatial,
1562         output_second_spatial);
1563   }
1564   ConvolutionDimensionNumbers dimension_numbers;
1565   dimension_numbers.set_input_batch_dimension(input_batch);
1566   dimension_numbers.set_input_feature_dimension(input_feature);
1567   dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
1568   dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
1569   dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
1570   dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
1571   dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
1572   dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
1573   dimension_numbers.set_output_batch_dimension(output_batch);
1574   dimension_numbers.set_output_feature_dimension(output_feature);
1575   dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
1576   dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
1577   return dimension_numbers;
1578 }
1579 
1580 }  // namespace xla
1581