1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/computation_builder.h"
17
18 #include <stddef.h>
19 #include <array>
20 #include <numeric>
21 #include <set>
22 #include <vector>
23
24 #include "tensorflow/compiler/xla/ptr_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/protobuf.h"
34
35 namespace xla {
36
ComputationBuilder(Client * client,const string & computation_name)37 ComputationBuilder::ComputationBuilder(Client* client,
38 const string& computation_name)
39 : name_(computation_name), client_(client) {}
40
~ComputationBuilder()41 ComputationBuilder::~ComputationBuilder() {}
42
NoteError(const Status & error)43 void ComputationBuilder::NoteError(const Status& error) {
44 if (die_immediately_on_error_) {
45 LOG(FATAL) << "error building computation: " << error;
46 }
47
48 if (first_error_.ok()) {
49 first_error_ = error;
50 first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
51 }
52 }
53
CreateSubBuilder(const string & computation_name)54 std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder(
55 const string& computation_name) {
56 auto sub_builder = MakeUnique<ComputationBuilder>(client_, computation_name);
57 sub_builder->parent_builder_ = this;
58 sub_builder->die_immediately_on_error_ = die_immediately_on_error_;
59 return sub_builder;
60 }
61
PrepareComputation()62 Status ComputationBuilder::PrepareComputation() {
63 TF_RETURN_IF_ERROR(first_error_);
64
65 if (!computation_.IsNull()) {
66 return Status::OK();
67 }
68
69 ComputationRequest request;
70 request.set_name(name_);
71 ComputationResponse response;
72
73 VLOG(2) << "making computation request";
74 Status s = client_->stub()->Computation(&request, &response);
75 VLOG(2) << "done with computation request";
76
77 if (!s.ok()) {
78 NoteError(s);
79 return first_error_;
80 }
81
82 computation_ = Computation(client_->stub(), response.computation());
83 return Status::OK();
84 }
85
RunOp(OpRequest * op_request,OpResponse * op_response)86 Status ComputationBuilder::RunOp(OpRequest* op_request,
87 OpResponse* op_response) {
88 TF_RETURN_IF_ERROR(first_error_);
89 TF_RETURN_IF_ERROR(PrepareComputation());
90
91 // Fill in fields that are set on every OpRequest.
92 *op_request->mutable_computation() = computation_.handle();
93 *op_request->mutable_metadata() = metadata_;
94 if (sharding_) {
95 *op_request->mutable_sharding() = *sharding_;
96 }
97
98 const string& op_name =
99 OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name();
100 VLOG(2) << "running op request: " << op_name;
101 Status status = client_->stub()->Op(op_request, op_response);
102 VLOG(2) << "done with op request: " << op_name;
103 return status;
104 }
105
RunOpAndNoteError(OpRequest * op_request)106 void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) {
107 OpResponse op_response;
108 Status status = RunOp(op_request, &op_response);
109 if (!status.ok()) {
110 NoteError(status);
111 }
112 }
113
RunOpAndParseResponse(OpRequest * op_request)114 ComputationDataHandle ComputationBuilder::RunOpAndParseResponse(
115 OpRequest* op_request) {
116 OpResponse op_response;
117 Status status = RunOp(op_request, &op_response);
118 if (!status.ok()) {
119 NoteError(status);
120 return ComputationDataHandle();
121 }
122 if (op_response.output().handle() == 0) {
123 NoteError(InternalError("No output handle"));
124 return ComputationDataHandle();
125 }
126 return op_response.output();
127 }
128
MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions,tensorflow::gtl::ArraySlice<int64> window_strides,tensorflow::gtl::ArraySlice<std::pair<int64,int64>> padding,tensorflow::gtl::ArraySlice<int64> lhs_dilation,tensorflow::gtl::ArraySlice<int64> rhs_dilation,Window * window)129 bool ComputationBuilder::MakeWindow(
130 tensorflow::gtl::ArraySlice<int64> window_dimensions,
131 tensorflow::gtl::ArraySlice<int64> window_strides,
132 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
133 tensorflow::gtl::ArraySlice<int64> lhs_dilation,
134 tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
135 const auto verify_size = [&](const size_t x, const char* x_name) {
136 if (x == 0 || x == window_dimensions.size()) {
137 return true;
138 } else {
139 NoteError(InvalidArgument(
140 "%s", tensorflow::strings::StrCat(
141 "Window has different number of window dimensions than of ",
142 x_name, "\nNumber of window dimensions: ",
143 window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
144 "\n")
145 .c_str())); //
146 return false;
147 }
148 };
149 if (!verify_size(window_strides.size(), "window strides") ||
150 !verify_size(padding.size(), "padding entries") ||
151 !verify_size(lhs_dilation.size(), "lhs dilation factors") ||
152 !verify_size(rhs_dilation.size(), "rhs dilation factors")) {
153 return false;
154 }
155
156 window->Clear();
157 for (size_t i = 0; i < window_dimensions.size(); i++) {
158 auto dim = window->add_dimensions();
159 dim->set_size(window_dimensions[i]);
160 if (!window_strides.empty()) {
161 dim->set_stride(window_strides[i]);
162 } else {
163 dim->set_stride(1);
164 }
165 if (!padding.empty()) {
166 dim->set_padding_low(padding[i].first);
167 dim->set_padding_high(padding[i].second);
168 } else {
169 dim->set_padding_low(0);
170 dim->set_padding_high(0);
171 }
172 if (!lhs_dilation.empty()) {
173 dim->set_base_dilation(lhs_dilation[i]);
174 } else {
175 dim->set_base_dilation(1);
176 }
177 if (!rhs_dilation.empty()) {
178 dim->set_window_dilation(rhs_dilation[i]);
179 } else {
180 dim->set_window_dilation(1);
181 }
182 dim->set_window_reversal(false);
183 }
184 return true;
185 }
186
ConstantLiteral(const Literal & literal)187 ComputationDataHandle ComputationBuilder::ConstantLiteral(
188 const Literal& literal) {
189 OpRequest op_request;
190 ConstantRequest* request = op_request.mutable_constant_request();
191 *request->mutable_literal() = literal.ToProto();
192 VLOG(3) << "created constant: " << request->literal().ShortDebugString();
193 return RunOpAndParseResponse(&op_request);
194 }
195
Parameter(int64 parameter_number,const Shape & shape,const string & name)196 ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
197 const Shape& shape,
198 const string& name) {
199 OpRequest op_request;
200 ParameterRequest* request = op_request.mutable_parameter_request();
201 *request->mutable_shape() = shape;
202 request->set_parameter(parameter_number);
203 request->set_name(name);
204 return RunOpAndParseResponse(&op_request);
205 }
206
GetShapeWithoutNoteError(const ComputationDataHandle & operand)207 StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShapeWithoutNoteError(
208 const ComputationDataHandle& operand) {
209 GetLocalShapeRequest request;
210 *request.mutable_computation() = computation_.handle();
211 *request.mutable_operand() = operand;
212 GetLocalShapeResponse response;
213
214 VLOG(2) << "making get-shape request";
215 TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response));
216 VLOG(2) << "done with request";
217
218 TF_RET_CHECK(response.has_shape());
219 std::unique_ptr<Shape> shape = WrapUnique(response.release_shape());
220 TF_RET_CHECK(shape != nullptr);
221 return std::move(shape);
222 }
223
GetShape(const ComputationDataHandle & operand)224 StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
225 const ComputationDataHandle& operand) {
226 TF_RETURN_IF_ERROR(first_error_);
227
228 auto status_or_shape = GetShapeWithoutNoteError(operand);
229 if (!status_or_shape.ok()) {
230 NoteError(status_or_shape.status());
231 return first_error_;
232 }
233 return status_or_shape;
234 }
235
GetProgramShape()236 StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
237 TF_RETURN_IF_ERROR(first_error_);
238
239 GetComputationShapeRequest request;
240 *request.mutable_computation() = computation_.handle();
241 GetComputationShapeResponse response;
242
243 VLOG(2) << "making get-program-shape-request";
244 Status status = client_->stub()->GetComputationShape(&request, &response);
245 VLOG(2) << "done with get-program-shape-request";
246
247 if (!status.ok()) {
248 first_error_ = status;
249 return status;
250 }
251
252 TF_RET_CHECK(response.has_program_shape());
253 return std::move(*response.mutable_program_shape());
254 }
255
// Debug helper: CHECK-fails unless `operand` has exactly `expected_shape`, and
// otherwise returns `operand` unchanged. The shape lookup uses
// ConsumeValueOrDie, which presumably aborts if the shape is unavailable.
ComputationDataHandle ComputationBuilder::CheckShape(
    const ComputationDataHandle& operand, const Shape& expected_shape) {
  const std::unique_ptr<Shape> actual = GetShape(operand).ConsumeValueOrDie();
  const bool matches = ShapeUtil::Equal(expected_shape, *actual);
  CHECK(matches) << "want " << ShapeUtil::HumanString(expected_shape)
                 << " got " << ShapeUtil::HumanString(*actual);
  return operand;
}
264
CheckSameShape(const ComputationDataHandle & lhs,const ComputationDataHandle & rhs)265 void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
266 const ComputationDataHandle& rhs) {
267 std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie();
268 std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie();
269 VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals "
270 << ShapeUtil::HumanString(*rhs_shape);
271 CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
272 << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
273 << ShapeUtil::HumanString(*rhs_shape);
274 }
275
// Adds a slice of `operand` described by per-dimension start indices, limit
// indices and strides. The three lists are forwarded to the service as given;
// no validation happens here.
ComputationDataHandle ComputationBuilder::Slice(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  OpRequest op;
  SliceRequest* slice = op.mutable_slice_request();
  *slice->mutable_operand() = operand;
  for (const int64 start : start_indices) slice->add_start_indices(start);
  for (const int64 limit : limit_indices) slice->add_limit_indices(limit);
  for (const int64 step : strides) slice->add_strides(step);
  return RunOpAndParseResponse(&op);
}
295
// Slices `operand` along the single dimension `dimno`, from `start_index` to
// `limit_index` with step `stride`. Every other dimension is kept whole
// (start 0, limit = its full size, stride 1); the work is delegated to Slice.
// Returns an empty handle, with the error recorded, when either:
//   - the operand's shape cannot be obtained, or
//   - `dimno` is not a valid dimension index of that shape.
ComputationDataHandle ComputationBuilder::SliceInDim(
    const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
    int64 stride, int64 dimno) {
  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    NoteError(shape_status.status());
    return ComputationDataHandle{};
  }
  const Shape& shape = *shape_status.ValueOrDie();
  const int64 rank = ShapeUtil::Rank(shape);
  // `dimno` indexes the vectors below directly. An out-of-range value would be
  // an out-of-bounds write, so reject it with a proper error instead.
  if (dimno < 0 || dimno >= rank) {
    NoteError(InvalidArgument(
        "SliceInDim dimension %lld is out of bounds for operand of rank %lld",
        static_cast<long long>(dimno), static_cast<long long>(rank)));
    return ComputationDataHandle{};
  }
  std::vector<int64> starts(rank, 0);
  std::vector<int64> limits(shape.dimensions().begin(),
                            shape.dimensions().end());
  std::vector<int64> strides(rank, 1);
  starts[dimno] = start_index;
  limits[dimno] = limit_index;
  strides[dimno] = stride;
  return Slice(operand, starts, limits, strides);
}
314
// Adds a slice of `operand` whose start position comes from the
// `start_indices` operand and whose extents are `slice_sizes`.
ComputationDataHandle ComputationBuilder::DynamicSlice(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  OpRequest op;
  DynamicSliceRequest* dyn_slice = op.mutable_dynamic_slice_request();
  *dyn_slice->mutable_operand() = operand;
  *dyn_slice->mutable_start_indices() = start_indices;
  for (const int64 extent : slice_sizes) {
    dyn_slice->add_slice_sizes(extent);
  }
  return RunOpAndParseResponse(&op);
}
328
// Adds a dynamic-update-slice op over `operand`, `update` and
// `start_indices`. The three handles are forwarded to the service unchanged.
ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
    const ComputationDataHandle& operand, const ComputationDataHandle& update,
    const ComputationDataHandle& start_indices) {
  OpRequest op;
  DynamicUpdateSliceRequest* dus =
      op.mutable_dynamic_update_slice_request();
  *dus->mutable_start_indices() = start_indices;
  *dus->mutable_update() = update;
  *dus->mutable_operand() = operand;
  return RunOpAndParseResponse(&op);
}
340
ConcatInDim(tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,int64 dimension)341 ComputationDataHandle ComputationBuilder::ConcatInDim(
342 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
343 int64 dimension) {
344 OpRequest op_request;
345 ConcatenateRequest* request = op_request.mutable_concatenate_request();
346 for (const ComputationDataHandle& operand : operands) {
347 *request->add_operands() = operand;
348 }
349 request->set_dimension(dimension);
350 return RunOpAndParseResponse(&op_request);
351 }
352
// Adds a broadcast of `operand`. `broadcast_sizes` is forwarded to the
// request unchanged.
ComputationDataHandle ComputationBuilder::Broadcast(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  OpRequest op;
  BroadcastRequest* bcast = op.mutable_broadcast_request();
  *bcast->mutable_operand() = operand;
  for (const int64 extent : broadcast_sizes) {
    bcast->add_broadcast_sizes(extent);
  }
  return RunOpAndParseResponse(&op);
}
364
// Adds a pad op over `operand`, using `padding_value` as the fill value and
// `padding_config` as the padding configuration.
ComputationDataHandle ComputationBuilder::Pad(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& padding_value,
    const PaddingConfig& padding_config) {
  OpRequest op;
  PadRequest* pad = op.mutable_pad_request();
  *pad->mutable_padding_config() = padding_config;
  *pad->mutable_padding_value() = padding_value;
  *pad->mutable_operand() = operand;
  return RunOpAndParseResponse(&op);
}
376
// Adds a reshape of `operand` to `new_sizes`. `dimensions` is forwarded as the
// request's dimension order; the two-argument overload passes 0..rank-1.
ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  OpRequest op;
  ReshapeRequest* reshape = op.mutable_reshape_request();
  *reshape->mutable_operand() = operand;
  for (const int64 dim : dimensions) reshape->add_dimensions(dim);
  for (const int64 size : new_sizes) reshape->add_new_sizes(size);
  return RunOpAndParseResponse(&op);
}
392
// Reshapes `operand` to `new_sizes`, taking the operand's dimensions in their
// natural order 0..rank-1. Returns an empty handle when either:
//   - an error was already recorded, or
//   - the operand's shape cannot be obtained.
ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> operand_shape = GetShape(operand);
  if (!operand_shape.ok()) {
    return ComputationDataHandle();
  }
  const int64 rank = operand_shape.ValueOrDie()->dimensions().size();
  std::vector<int64> natural_order;
  natural_order.reserve(rank);
  for (int64 d = 0; d < rank; ++d) {
    natural_order.push_back(d);
  }
  return Reshape(operand, natural_order, new_sizes);
}
408
// Collapses the consecutive, increasing run of dimensions `dims_to_collapse`
// of `operand` into one dimension whose size is the product of theirs, then
// emits the result as a Reshape. Collapsing zero or one dimension returns
// `operand` unchanged.
// Returns an empty handle (recording an error where noted) when:
//   - an error was already recorded,
//   - the dimensions are not consecutive and increasing (error recorded),
//   - the operand's shape is unavailable, or
//   - the run lies outside [0, rank) (error recorded).
ComputationDataHandle ComputationBuilder::Collapse(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dims_to_collapse) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  // Don't support out-of-order collapse here.
  // Checks that the collapsed dimensions are in order and consecutive.
  for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
       i < dims_to_collapse.size(); ++i) {
    if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) {
      NoteError(InvalidArgument(
          "Collapsed dimensions are not in order and consecutive."));
      return ComputationDataHandle();
    }
  }

  // Create a new sizes vector from the old shape, replacing the collapsed
  // dimensions by the product of their sizes.
  StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand);
  if (!shape_or_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();

  VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
  VLOG(3) << "dims to collapse: "
          << tensorflow::str_util::Join(dims_to_collapse, ",");

  if (dims_to_collapse.size() <= 1) {
    // Not collapsing anything, trivially we can return the operand versus
    // enqueueing a trivial reshape.
    return operand;
  }

  // The run is consecutive (checked above), so checking its endpoints checks
  // every element. Without this check, a negative first dimension would make
  // the loop below call back() on an empty vector (undefined behavior). A last
  // dimension >= rank would silently collapse fewer dimensions than requested.
  const int64 rank = ShapeUtil::Rank(*original_shape);
  if (dims_to_collapse.front() < 0 || dims_to_collapse.back() >= rank) {
    NoteError(InvalidArgument(
        "Collapsed dimensions [%lld, %lld] are out of bounds for operand of "
        "rank %lld.",
        static_cast<long long>(dims_to_collapse.front()),
        static_cast<long long>(dims_to_collapse.back()),
        static_cast<long long>(rank)));
    return ComputationDataHandle();
  }

  std::vector<int64> new_sizes;
  for (int i = 0; i < rank; ++i) {
    if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) {
      new_sizes.push_back(original_shape->dimensions(i));
    } else {
      // Fold this dimension into the first dimension of the collapsed run.
      new_sizes.back() *= original_shape->dimensions(i);
    }
  }

  VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
          << "]";

  return Reshape(operand, new_sizes);
}
459
Trace(const string & tag,const ComputationDataHandle & operand)460 void ComputationBuilder::Trace(const string& tag,
461 const ComputationDataHandle& operand) {
462 OpRequest op_request;
463 TraceRequest* request = op_request.mutable_trace_request();
464 request->set_tag(tag);
465 *request->mutable_operand() = operand;
466 RunOpAndNoteError(&op_request);
467 }
468
// Adds a TRIOP_SELECT ternary op over (pred, on_true, on_false).
ComputationDataHandle ComputationBuilder::Select(
    const ComputationDataHandle& pred,
    const ComputationDataHandle& on_true,
    const ComputationDataHandle& on_false) {
  return TernaryOp(TRIOP_SELECT, pred, on_true, on_false);
}
474
Tuple(tensorflow::gtl::ArraySlice<ComputationDataHandle> elements)475 ComputationDataHandle ComputationBuilder::Tuple(
476 tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
477 OpRequest op_request;
478 VariadicOpRequest* request = op_request.mutable_variadic_op_request();
479 request->set_varop(VAROP_TUPLE);
480 for (const ComputationDataHandle& operand : elements) {
481 *request->add_operands() = operand;
482 }
483 return RunOpAndParseResponse(&op_request);
484 }
485
// Adds an op extracting element `index` of the tuple `tuple_data`.
ComputationDataHandle ComputationBuilder::GetTupleElement(
    const ComputationDataHandle& tuple_data, int64 index) {
  OpRequest op;
  GetTupleElementRequest* gte = op.mutable_get_tuple_element_request();
  gte->set_index(index);
  *gte->mutable_operand() = tuple_data;
  return RunOpAndParseResponse(&op);
}
495
// Element-wise `lhs == rhs` (BINOP_EQ); `broadcast_dimensions` is passed
// through to BinaryOp.
ComputationDataHandle ComputationBuilder::Eq(
    const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions);
}
501
// Element-wise `lhs != rhs` (BINOP_NE); `broadcast_dimensions` is passed
// through to BinaryOp.
ComputationDataHandle ComputationBuilder::Ne(
    const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions);
}
507
// Element-wise `lhs >= rhs` (BINOP_GE); `broadcast_dimensions` is passed
// through to BinaryOp.
ComputationDataHandle ComputationBuilder::Ge(
    const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions);
}
513
// Element-wise `lhs > rhs` (BINOP_GT); `broadcast_dimensions` is passed
// through to BinaryOp.
ComputationDataHandle ComputationBuilder::Gt(
    const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions);
}
519
// Element-wise `lhs <= rhs` (BINOP_LE); `broadcast_dimensions` is passed
// through to BinaryOp.
ComputationDataHandle ComputationBuilder::Le(
    const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions);
}
525
// Element-wise `lhs < rhs` (BINOP_LT); `broadcast_dimensions` is passed
// through to BinaryOp.
ComputationDataHandle ComputationBuilder::Lt(
    const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions);
}
531
// Conventional dot product, delegated to DotGeneral. It contracts dimension 0
// of `rhs` with dimension 0 of `lhs` when lhs has a single dimension, and
// with dimension 1 of `lhs` otherwise. Returns an empty handle if lhs's shape
// cannot be obtained.
ComputationDataHandle ComputationBuilder::Dot(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
  StatusOr<std::unique_ptr<Shape>> lhs_shape_or = GetShape(lhs);
  if (!lhs_shape_or.ok()) {
    return ComputationDataHandle();
  }
  const std::unique_ptr<Shape> lhs_shape = lhs_shape_or.ConsumeValueOrDie();
  const int64 lhs_contracting = lhs_shape->dimensions_size() == 1 ? 0 : 1;

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(lhs_contracting);
  dnums.add_rhs_contracting_dimensions(0);
  return DotGeneral(lhs, rhs, dnums);
}
546
// Adds a dot op over `lhs` and `rhs` with explicit `dimension_numbers`.
ComputationDataHandle ComputationBuilder::DotGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  OpRequest op;
  DotRequest* dot = op.mutable_dot_request();
  *dot->mutable_dimension_numbers() = dimension_numbers;
  *dot->mutable_lhs() = lhs;
  *dot->mutable_rhs() = rhs;
  return RunOpAndParseResponse(&op);
}
557
// Convolution with a symbolic Padding mode. The dimension numbers are the
// defaults produced by CreateDefaultConvDimensionNumbers for as many spatial
// dimensions as there are window strides.
ComputationDataHandle ComputationBuilder::Conv(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  const auto default_dnums =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvWithGeneralDimensions(lhs, rhs, window_strides, padding,
                                   default_dnums);
}
565
// Convolution with explicit (low, high) padding pairs. The dimension numbers
// are the defaults produced by CreateDefaultConvDimensionNumbers for as many
// spatial dimensions as there are window strides.
ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  const auto default_dnums =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvGeneral(lhs, rhs, window_strides, padding, default_dnums);
}
573
VerifyConvolution(const Shape & lhs_shape,const Shape & rhs_shape,const ConvolutionDimensionNumbers & dimension_numbers)574 bool ComputationBuilder::VerifyConvolution(
575 const Shape& lhs_shape, const Shape& rhs_shape,
576 const ConvolutionDimensionNumbers& dimension_numbers) {
577 if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
578 NoteError(
579 InvalidArgument("Convolution arguments must have same number of "
580 "dimensions. Got: %s and %s",
581 ShapeUtil::HumanString(lhs_shape).c_str(),
582 ShapeUtil::HumanString(rhs_shape).c_str()));
583 return false;
584 }
585 int num_dims = ShapeUtil::Rank(lhs_shape);
586 if (num_dims < 2) {
587 NoteError(InvalidArgument(
588 "Convolution expects argument arrays with >= 3 dimensions. "
589 "Got: %s and %s",
590 ShapeUtil::HumanString(lhs_shape).c_str(),
591 ShapeUtil::HumanString(rhs_shape).c_str()));
592 return false;
593 }
594 int num_spatial_dims = num_dims - 2;
595
596 const auto check_spatial_dimensions =
597 [&](const char* const field_name,
598 const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
599 numbers) {
600 if (numbers.size() != num_spatial_dims) {
601 NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
602 num_spatial_dims, field_name,
603 numbers.size()));
604 return false;
605 }
606 for (int i = 0; i < numbers.size(); ++i) {
607 if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
608 NoteError(
609 InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
610 field_name, i, numbers.Get(i)));
611 return false;
612 }
613 }
614 return true;
615 };
616 return check_spatial_dimensions(
617 "input_spatial_dimensions",
618 dimension_numbers.input_spatial_dimensions()) &&
619 check_spatial_dimensions(
620 "kernel_spatial_dimensions",
621 dimension_numbers.kernel_spatial_dimensions()) &&
622 check_spatial_dimensions(
623 "output_spatial_dimensions",
624 dimension_numbers.output_spatial_dimensions());
625 }
626
// Convolution whose padding is a symbolic Padding mode rather than explicit
// (low, high) pairs. MakePadding receives:
//   - the spatial extents of `lhs`, selected by input_spatial_dimensions;
//   - the spatial extents of `rhs`, selected by kernel_spatial_dimensions;
//   - `window_strides` and `padding`.
// Its result is forwarded to ConvGeneral.
// Returns an empty handle when:
//   - an error was already recorded,
//   - the computation cannot be prepared,
//   - either operand's shape is unavailable, or
//   - VerifyConvolution rejects the operands (it records its own error).
ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
  if (!rhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();

  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    // Error is recorded in VerifyConvolution, as in ConvGeneralDilated. A
    // further NoteError here could never take effect: the first error is
    // sticky, and with die_immediately_on_error_ the process already aborted.
    return ComputationDataHandle();
  }

  std::vector<int64> base_area_dimensions(
      dimension_numbers.input_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
       ++i) {
    base_area_dimensions[i] =
        lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
  }

  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  return ConvGeneral(lhs, rhs, window_strides,
                     MakePadding(base_area_dimensions, window_dimensions,
                                 window_strides, padding),
                     dimension_numbers);
}
673
// ConvGeneralDilated with empty dilation lists. MakeWindow treats an empty
// list as dilation 1, so this means no dilation.
ComputationDataHandle ComputationBuilder::ConvGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding,
                            /*lhs_dilation=*/{}, /*rhs_dilation=*/{},
                            dimension_numbers);
}
682
// Most general convolution entry point. Steps:
//   1. Validate the operands with VerifyConvolution.
//   2. Build the window from the kernel's spatial extents plus the given
//      strides, padding and dilations.
//   3. Issue a ConvolveRequest.
// Returns an empty handle when:
//   - an error was already recorded,
//   - the computation cannot be prepared,
//   - either shape is unavailable, or
//   - validation or window construction fails (those record their own errors).
ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> lhs_or = GetShape(lhs);
  if (!lhs_or.ok()) {
    return ComputationDataHandle();
  }
  StatusOr<std::unique_ptr<Shape>> rhs_or = GetShape(rhs);
  if (!rhs_or.ok()) {
    return ComputationDataHandle();
  }
  const std::unique_ptr<Shape> lhs_shape = lhs_or.ConsumeValueOrDie();
  const std::unique_ptr<Shape> rhs_shape = rhs_or.ConsumeValueOrDie();

  // VerifyConvolution records its own error on failure.
  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    return ComputationDataHandle();
  }

  // The window spans the kernel's spatial dimensions.
  std::vector<int64> window_dimensions;
  window_dimensions.reserve(dimension_numbers.kernel_spatial_dimensions_size());
  for (int64 kernel_dim : dimension_numbers.kernel_spatial_dimensions()) {
    window_dimensions.push_back(rhs_shape->dimensions(kernel_dim));
  }

  OpRequest op;
  ConvolveRequest* conv = op.mutable_convolve_request();
  *conv->mutable_lhs() = lhs;
  *conv->mutable_rhs() = rhs;
  *conv->mutable_dimension_numbers() = dimension_numbers;

  // MakeWindow records its own error on a size mismatch.
  const bool window_ok =
      MakeWindow(window_dimensions, window_strides, padding, lhs_dilation,
                 rhs_dilation, conv->mutable_window());
  if (!window_ok) {
    return ComputationDataHandle();
  }
  return RunOpAndParseResponse(&op);
}
732
// Adds an FFT op of kind `fft_type` over `operand`, with the lengths listed
// in `fft_length`.
ComputationDataHandle ComputationBuilder::Fft(
    const ComputationDataHandle& operand, const FftType fft_type,
    const tensorflow::gtl::ArraySlice<int64> fft_length) {
  OpRequest op;
  FftRequest* fft = op.mutable_fft_request();
  fft->set_fft_type(fft_type);
  *fft->mutable_operand() = operand;
  for (const int64 len : fft_length) {
    fft->add_fft_length(len);
  }
  return RunOpAndParseResponse(&op);
}
745
Infeed(const Shape & shape,const string & config)746 ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
747 const string& config) {
748 OpRequest op_request;
749 InfeedRequest* request = op_request.mutable_infeed_request();
750 *request->mutable_shape() = shape;
751 *request->mutable_config() = config;
752 return RunOpAndParseResponse(&op_request);
753 }
754
Outfeed(const ComputationDataHandle & operand,const Shape & shape,const string & outfeed_config)755 void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
756 const Shape& shape,
757 const string& outfeed_config) {
758 OpRequest op_request;
759 OutfeedRequest* request = op_request.mutable_outfeed_request();
760 request->set_outfeed_config(outfeed_config);
761 *request->mutable_operand() = operand;
762 *request->mutable_shape() = shape;
763 RunOpAndNoteError(&op_request);
764 }
765
Call(const Computation & computation,tensorflow::gtl::ArraySlice<ComputationDataHandle> operands)766 ComputationDataHandle ComputationBuilder::Call(
767 const Computation& computation,
768 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
769 OpRequest op_request;
770 CallRequest* request = op_request.mutable_call_request();
771 *request->mutable_to_apply() = computation.handle();
772 for (const ComputationDataHandle& operand : operands) {
773 *request->add_operands() = operand;
774 }
775 return RunOpAndParseResponse(&op_request);
776 }
777
CustomCall(const string & call_target_name,tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,const Shape & shape)778 ComputationDataHandle ComputationBuilder::CustomCall(
779 const string& call_target_name,
780 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
781 const Shape& shape) {
782 OpRequest op_request;
783 CustomCallRequest* request = op_request.mutable_custom_call_request();
784 request->set_call_target_name(call_target_name);
785 for (const ComputationDataHandle& operand : operands) {
786 *request->add_operands() = operand;
787 }
788 *request->mutable_shape() = shape;
789 return RunOpAndParseResponse(&op_request);
790 }
791
HostCompute(tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,const string & channel_name,int64 cost_estimate_ns,const Shape & shape)792 ComputationDataHandle ComputationBuilder::HostCompute(
793 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
794 const string& channel_name, int64 cost_estimate_ns, const Shape& shape) {
795 OpRequest op_request;
796 HostComputeRequest* request = op_request.mutable_host_compute_request();
797 for (const ComputationDataHandle& operand : operands) {
798 *request->add_operands() = operand;
799 }
800 *request->mutable_shape() = shape;
801 request->set_channel_name(channel_name);
802 request->set_cost_estimate_ns(cost_estimate_ns);
803 return RunOpAndParseResponse(&op_request);
804 }
805
// Combines `real` and `imag` with BINOP_COMPLEX.
ComputationDataHandle ComputationBuilder::Complex(
    const ComputationDataHandle& real, const ComputationDataHandle& imag,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_COMPLEX;
  return BinaryOp(op, real, imag, broadcast_dimensions);
}
811
// Complex conjugate of `operand`: Complex(Real(operand), -Imag(operand)).
ComputationDataHandle ComputationBuilder::Conj(
    const ComputationDataHandle& operand) {
  // Each of Real/Imag/Neg adds an operation to the computation being built.
  // As nested call arguments, the order in which Real and Imag/Neg are added
  // is unspecified and can differ between compilers. Separate statements fix
  // it to Real, then Imag, then Neg.
  const ComputationDataHandle real_part = Real(operand);
  const ComputationDataHandle imag_part = Imag(operand);
  const ComputationDataHandle negated_imag = Neg(imag_part);
  return Complex(real_part, negated_imag);
}
816
// Adds a BINOP_ADD of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Add(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_ADD;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
822
// Adds a BINOP_SUB of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Sub(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_SUB;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
828
// Adds a BINOP_MUL of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Mul(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_MUL;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
834
// Adds a BINOP_DIV of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Div(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_DIV;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
840
// Adds a BINOP_REM of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Rem(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_REM;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
846
// Adds a BINOP_MAX of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Max(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_MAX;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
852
// Adds a BINOP_MIN of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Min(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_MIN;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
858
// Adds a BINOP_AND of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::And(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_AND;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
864
// Adds a BINOP_OR of `lhs` and `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::Or(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_OR;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
870
// Adds a UNOP_NOT of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Not(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_NOT;
  return UnaryOp(op, operand);
}
875
// Adds a BINOP_SHIFT_LEFT of `lhs` by `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::ShiftLeft(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_SHIFT_LEFT;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
881
// Adds a BINOP_SHIFT_RIGHT_ARITHMETIC of `lhs` by `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_SHIFT_RIGHT_ARITHMETIC;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
887
// Adds a BINOP_SHIFT_RIGHT_LOGICAL of `lhs` by `rhs` (see BinaryOp).
ComputationDataHandle ComputationBuilder::ShiftRightLogical(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_SHIFT_RIGHT_LOGICAL;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
893
// Adds a UNOP_ABS of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Abs(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_ABS;
  return UnaryOp(op, operand);
}
898
// Adds a BINOP_ATAN2 with `y` as the left and `x` as the right operand.
ComputationDataHandle ComputationBuilder::Atan2(
    const ComputationDataHandle& y, const ComputationDataHandle& x,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_ATAN2;
  return BinaryOp(op, y, x, broadcast_dimensions);
}
904
// Adds a UNOP_EXP of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Exp(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_EXP;
  return UnaryOp(op, operand);
}
909
// Adds a UNOP_FLOOR of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Floor(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_FLOOR;
  return UnaryOp(op, operand);
}
914
// Adds a UNOP_CEIL of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Ceil(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_CEIL;
  return UnaryOp(op, operand);
}
919
// Adds a UNOP_ROUND_NEAREST_AFZ of `operand` (going by the enumerator name,
// round-to-nearest with ties away from zero — confirm against the proto).
ComputationDataHandle ComputationBuilder::Round(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_ROUND_NEAREST_AFZ;
  return UnaryOp(op, operand);
}
924
// Adds a UNOP_LOG of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Log(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_LOG;
  return UnaryOp(op, operand);
}
929
// Adds a UNOP_SIGN of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Sign(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_SIGN;
  return UnaryOp(op, operand);
}
934
// Adds a UNOP_COS of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Cos(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_COS;
  return UnaryOp(op, operand);
}
939
// Adds a UNOP_SIN of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Sin(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_SIN;
  return UnaryOp(op, operand);
}
944
// Adds a UNOP_TANH of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Tanh(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_TANH;
  return UnaryOp(op, operand);
}
949
// Adds a UNOP_REAL of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Real(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_REAL;
  return UnaryOp(op, operand);
}
954
// Adds a UNOP_IMAG of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Imag(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_IMAG;
  return UnaryOp(op, operand);
}
959
// Adds a UNOP_IS_FINITE of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::IsFinite(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_IS_FINITE;
  return UnaryOp(op, operand);
}
964
// Adds a transpose of `operand`; `permutation` is copied, in order, into the
// request's dimensions field.
ComputationDataHandle ComputationBuilder::Transpose(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> permutation) {
  OpRequest op_request;
  TransposeRequest* transpose = op_request.mutable_transpose_request();
  *transpose->mutable_operand() = operand;
  for (size_t i = 0; i < permutation.size(); ++i) {
    transpose->add_dimensions(permutation[i]);
  }
  return RunOpAndParseResponse(&op_request);
}
976
// Adds a reverse of `operand` along the listed `dimensions`.
ComputationDataHandle ComputationBuilder::Rev(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  OpRequest op_request;
  ReverseRequest* reverse = op_request.mutable_reverse_request();
  *reverse->mutable_operand() = operand;
  for (size_t i = 0; i < dimensions.size(); ++i) {
    reverse->add_dimensions(dimensions[i]);
  }
  return RunOpAndParseResponse(&op_request);
}
988
// Adds a UNOP_SORT of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Sort(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_SORT;
  return UnaryOp(op, operand);
}
993
// Square root expressed as pow(operand, 0.5) using an R0 float constant
// exponent and no explicit broadcast dimensions.
ComputationDataHandle ComputationBuilder::SqrtF32(
    const ComputationDataHandle& operand) {
  const ComputationDataHandle exponent = ConstantR0<float>(0.5);
  return BinaryOp(BINOP_POW, operand, exponent,
                  /*broadcast_dimensions=*/{});
}
999
// Adds a BINOP_POW with `lhs` as base and `rhs` as exponent.
ComputationDataHandle ComputationBuilder::Pow(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const BinaryOperation op = BINOP_POW;
  return BinaryOp(op, lhs, rhs, broadcast_dimensions);
}
1005
// Adds a convert_request that changes the element type of `operand` to
// `new_element_type`.
//
// Returns an empty handle without issuing the request if the builder already
// holds an error, if PrepareComputation() fails, or if the shape of `operand`
// cannot be fetched.
ComputationDataHandle ComputationBuilder::ConvertElementType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  // The shape itself is not used to build the request; the lookup only
  // serves as an early-exit check on `operand`.
  if (!GetShape(operand).ok()) {
    return ComputationDataHandle();
  }

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}
1024
// Adds a bitcast_convert_request that reinterprets `operand` with element
// type `new_element_type`.
//
// Returns an empty handle without issuing the request if the builder already
// holds an error, if PrepareComputation() fails, or if the shape of `operand`
// cannot be fetched.
ComputationDataHandle ComputationBuilder::BitcastConvertType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  // The shape itself is not used to build the request; the lookup only
  // serves as an early-exit check on `operand`.
  if (!GetShape(operand).ok()) {
    return ComputationDataHandle();
  }

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_bitcast_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}
1043
// Square expressed as pow(operand, 2.0) using an R0 float constant exponent
// and no explicit broadcast dimensions.
ComputationDataHandle ComputationBuilder::SquareF32(
    const ComputationDataHandle& operand) {
  const ComputationDataHandle exponent = ConstantR0<float>(2.0);
  return BinaryOp(BINOP_POW, operand, exponent,
                  /*broadcast_dimensions=*/{});
}
1049
// Reciprocal expressed as pow(operand, -1.0) using an R0 float constant
// exponent and no explicit broadcast dimensions.
ComputationDataHandle ComputationBuilder::ReciprocalF32(
    const ComputationDataHandle& operand) {
  const ComputationDataHandle exponent = ConstantR0<float>(-1.0);
  return BinaryOp(BINOP_POW, operand, exponent,
                  /*broadcast_dimensions=*/{});
}
1055
// Adds a UNOP_NEGATE of `operand` (see UnaryOp).
ComputationDataHandle ComputationBuilder::Neg(
    const ComputationDataHandle& operand) {
  const UnaryOperation op = UNOP_NEGATE;
  return UnaryOp(op, operand);
}
1060
// Adds a TRIOP_CLAMP with operands (min, operand, max) in that order.
ComputationDataHandle ComputationBuilder::Clamp(
    const ComputationDataHandle& min, const ComputationDataHandle& operand,
    const ComputationDataHandle& max) {
  const TernaryOperation op = TRIOP_CLAMP;
  return TernaryOp(op, min, operand, max);
}
1066
UnaryOp(UnaryOperation unop,const ComputationDataHandle & operand)1067 ComputationDataHandle ComputationBuilder::UnaryOp(
1068 UnaryOperation unop, const ComputationDataHandle& operand) {
1069 OpRequest op_request;
1070 UnaryOpRequest* request = op_request.mutable_unary_op_request();
1071 request->set_unop(unop);
1072 *request->mutable_operand() = operand;
1073 return RunOpAndParseResponse(&op_request);
1074 }
1075
BinaryOp(BinaryOperation binop,const ComputationDataHandle & lhs,const ComputationDataHandle & rhs,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)1076 ComputationDataHandle ComputationBuilder::BinaryOp(
1077 BinaryOperation binop, const ComputationDataHandle& lhs,
1078 const ComputationDataHandle& rhs,
1079 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
1080 OpRequest op_request;
1081 BinaryOpRequest* request = op_request.mutable_binary_op_request();
1082 request->set_binop(binop);
1083 *request->mutable_lhs() = lhs;
1084 *request->mutable_rhs() = rhs;
1085 for (int64 dimension : broadcast_dimensions) {
1086 request->add_broadcast_dimensions(dimension);
1087 }
1088 return RunOpAndParseResponse(&op_request);
1089 }
1090
RngOp(RandomDistribution distribution,tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,const Shape & shape)1091 ComputationDataHandle ComputationBuilder::RngOp(
1092 RandomDistribution distribution,
1093 tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
1094 const Shape& shape) {
1095 OpRequest op_request;
1096 RngRequest* request = op_request.mutable_rng_request();
1097 request->set_distribution(distribution);
1098 for (const ComputationDataHandle& param : parameters) {
1099 *request->add_parameter() = param;
1100 }
1101 *request->mutable_shape() = shape;
1102 return RunOpAndParseResponse(&op_request);
1103 }
1104
TernaryOp(TernaryOperation triop,const ComputationDataHandle & lhs,const ComputationDataHandle & rhs,const ComputationDataHandle & ehs)1105 ComputationDataHandle ComputationBuilder::TernaryOp(
1106 TernaryOperation triop, const ComputationDataHandle& lhs,
1107 const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
1108 OpRequest op_request;
1109 TernaryOpRequest* request = op_request.mutable_ternary_op_request();
1110 request->set_triop(triop);
1111 *request->mutable_lhs() = lhs;
1112 *request->mutable_rhs() = rhs;
1113 *request->mutable_ehs() = ehs;
1114 return RunOpAndParseResponse(&op_request);
1115 }
1116
SetReturnValue(const ComputationDataHandle & operand)1117 Status ComputationBuilder::SetReturnValue(
1118 const ComputationDataHandle& operand) {
1119 TF_RETURN_IF_ERROR(first_error_);
1120
1121 SetReturnValueRequest request;
1122 *request.mutable_computation() = computation_.handle();
1123 *request.mutable_operand() = operand;
1124
1125 SetReturnValueResponse response;
1126
1127 VLOG(2) << "making set-handle-to-execute request";
1128 Status s = client_->stub()->SetReturnValue(&request, &response);
1129 VLOG(2) << "done with request";
1130
1131 if (!s.ok()) {
1132 NoteError(s);
1133 return first_error_;
1134 }
1135
1136 return Status::OK();
1137 }
1138
IsConstant(const ComputationDataHandle & operand,int64 num_parameters)1139 StatusOr<bool> ComputationBuilder::IsConstant(
1140 const ComputationDataHandle& operand, int64 num_parameters) {
1141 TF_RETURN_IF_ERROR(first_error_);
1142
1143 IsConstantRequest request;
1144 *request.mutable_computation() = computation_.handle();
1145 *request.mutable_operand() = operand;
1146 request.set_num_parameters(num_parameters);
1147 IsConstantResponse response;
1148
1149 VLOG(2) << "making IsConstant request";
1150 Status s = client_->stub()->IsConstant(&request, &response);
1151 VLOG(2) << "done with request";
1152
1153 if (!s.ok()) {
1154 return s;
1155 }
1156 return response.is_constant();
1157 }
1158
ComputeConstant(const ComputationDataHandle & operand,const Layout * output_layout,tensorflow::gtl::ArraySlice<Literal> parameters)1159 StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
1160 const ComputationDataHandle& operand, const Layout* output_layout,
1161 tensorflow::gtl::ArraySlice<Literal> parameters) {
1162 TF_RETURN_IF_ERROR(first_error_);
1163
1164 ComputeConstantRequest request;
1165 *request.mutable_computation() = computation_.handle();
1166 *request.mutable_operand() = operand;
1167 if (output_layout != nullptr) {
1168 *request.mutable_output_layout() = *output_layout;
1169 }
1170 for (const auto& param : parameters) {
1171 *request.add_parameters() = param.ToProto();
1172 }
1173
1174 ComputeConstantResponse response;
1175
1176 VLOG(2) << "making compute-constant request";
1177 Status s = client_->stub()->ComputeConstant(&request, &response);
1178 VLOG(2) << "done with request";
1179
1180 if (!s.ok()) {
1181 return s;
1182 }
1183
1184 VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
1185
1186 if (!response.has_literal()) {
1187 return InternalError(
1188 "no computed literal in the provided response in ComputeConstant "
1189 "request");
1190 }
1191 return Literal::CreateFromProto(response.literal());
1192 }
1193
Map(tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,const Computation & computation,tensorflow::gtl::ArraySlice<int64> dimensions,tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands)1194 ComputationDataHandle ComputationBuilder::Map(
1195 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
1196 const Computation& computation,
1197 tensorflow::gtl::ArraySlice<int64> dimensions,
1198 tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
1199 OpRequest op_request;
1200 MapRequest* request = op_request.mutable_map_request();
1201 for (const ComputationDataHandle& operand : operands) {
1202 *request->add_operands() = operand;
1203 }
1204 *request->mutable_to_apply() = computation.handle();
1205 for (int64 dimension : dimensions) {
1206 request->add_dimensions(dimension);
1207 }
1208 for (const ComputationDataHandle& sop : static_operands) {
1209 *request->add_static_operands() = sop;
1210 }
1211 return RunOpAndParseResponse(&op_request);
1212 }
1213
// RNG_NORMAL random op parameterized by (mu, sigma), producing `shape`.
ComputationDataHandle ComputationBuilder::RngNormal(
    const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
    const Shape& shape) {
  const RandomDistribution distribution = RandomDistribution::RNG_NORMAL;
  return RngOp(distribution, {mu, sigma}, shape);
}
1219
// RNG_UNIFORM random op parameterized by (a, b), producing `shape`.
ComputationDataHandle ComputationBuilder::RngUniform(
    const ComputationDataHandle& a, const ComputationDataHandle& b,
    const Shape& shape) {
  const RandomDistribution distribution = RandomDistribution::RNG_UNIFORM;
  return RngOp(distribution, {a, b}, shape);
}
1225
// Adds a while loop with the given `condition` and `body` computations,
// starting from `init`.
ComputationDataHandle ComputationBuilder::While(
    const Computation& condition, const Computation& body,
    const ComputationDataHandle& init) {
  OpRequest op_request;
  WhileRequest* loop = op_request.mutable_while_request();
  *loop->mutable_init() = init;
  *loop->mutable_condition() = condition.handle();
  *loop->mutable_body() = body.handle();
  return RunOpAndParseResponse(&op_request);
}
1236
// Adds a gather from `input` at `gather_indices`, configured by
// `dimension_numbers`; `window_bounds` is copied in order.
ComputationDataHandle ComputationBuilder::Gather(
    const ComputationDataHandle& input,
    const ComputationDataHandle& gather_indices,
    const GatherDimensionNumbers& dimension_numbers,
    tensorflow::gtl::ArraySlice<int64> window_bounds) {
  OpRequest op_request;
  GatherRequest* gather = op_request.mutable_gather_request();
  *gather->mutable_dimension_numbers() = dimension_numbers;
  *gather->mutable_input() = input;
  *gather->mutable_gather_indices() = gather_indices;
  for (size_t i = 0; i < window_bounds.size(); ++i) {
    gather->add_window_bounds(window_bounds[i]);
  }
  return RunOpAndParseResponse(&op_request);
}
1252
// Adds a conditional on `predicate` that pairs `true_operand` with
// `true_computation` and `false_operand` with `false_computation`.
ComputationDataHandle ComputationBuilder::Conditional(
    const ComputationDataHandle& predicate,
    const ComputationDataHandle& true_operand,
    const Computation& true_computation,
    const ComputationDataHandle& false_operand,
    const Computation& false_computation) {
  OpRequest op_request;
  ConditionalRequest* cond = op_request.mutable_conditional_request();
  *cond->mutable_predicate() = predicate;

  // True branch.
  *cond->mutable_true_computation() = true_computation.handle();
  *cond->mutable_true_operand() = true_operand;

  // False branch.
  *cond->mutable_false_computation() = false_computation.handle();
  *cond->mutable_false_operand() = false_operand;

  return RunOpAndParseResponse(&op_request);
}
1268
// Adds a reduction of `operand` over `dimensions_to_reduce` using
// `computation`, starting from `init_value`.
ComputationDataHandle ComputationBuilder::Reduce(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  OpRequest op_request;
  ReduceRequest* reduce = op_request.mutable_reduce_request();
  *reduce->mutable_to_apply() = computation.handle();
  *reduce->mutable_operand() = operand;
  *reduce->mutable_init_value() = init_value;
  for (size_t i = 0; i < dimensions_to_reduce.size(); ++i) {
    reduce->add_dimensions(dimensions_to_reduce[i]);
  }
  return RunOpAndParseResponse(&op_request);
}
1283
// Reduces `operand` over every dimension: 0, 1, ..., rank-1.
//
// Returns an empty handle if the builder already holds an error, if
// PrepareComputation() fails, or if the operand's shape cannot be fetched.
ComputationDataHandle ComputationBuilder::ReduceAll(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> operand_shape = GetShape(operand);
  if (!operand_shape.ok()) {
    return ComputationDataHandle();
  }

  const int64 rank = ShapeUtil::Rank(*operand_shape.ValueOrDie());
  std::vector<int64> every_dim;
  every_dim.reserve(rank);
  for (int64 dim = 0; dim < rank; ++dim) {
    every_dim.push_back(dim);
  }
  return Reduce(operand, init_value, computation, every_dim);
}
1300
// Windowed reduction of `operand` with a Padding mode; the mode is expanded to
// explicit (low, high) pairs by MakePadding and forwarded to
// ReduceWindowWithGeneralPadding.
//
// Returns an empty handle if the builder already holds an error, if the
// operand's shape cannot be fetched, or if ValidatePaddingValues rejects the
// window configuration (that status is recorded via NoteError).
ComputationDataHandle ComputationBuilder::ReduceWindow(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }
  // `shape` owns the Shape and outlives this slice.
  const tensorflow::gtl::ArraySlice<int64> operand_dims =
      AsInt64Slice(shape.ValueOrDie()->dimensions());

  Status padding_valid =
      ValidatePaddingValues(operand_dims, window_dimensions, window_strides);
  if (!padding_valid.ok()) {
    // Use NoteError rather than assigning first_error_ directly, so that
    // die_immediately_on_error_ is honored and the error backtrace is
    // captured, as on every other error path of the builder.
    NoteError(padding_valid);
    return ComputationDataHandle();
  }

  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(operand_dims, window_dimensions, window_strides, padding);
  return ReduceWindowWithGeneralPadding(operand, init_value, computation,
                                        window_dimensions, window_strides,
                                        padding_values);
}
1330
// Windowed reduction of `operand` with explicit per-dimension (low, high)
// `padding` and no dilation. If MakeWindow fails, an InternalError is passed
// to NoteError and an empty handle is returned.
ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  OpRequest op_request;
  ReduceWindowRequest* reduce_window =
      op_request.mutable_reduce_window_request();
  *reduce_window->mutable_operand() = operand;
  *reduce_window->mutable_to_apply() = computation.handle();
  *reduce_window->mutable_init_value() = init_value;

  const bool window_built =
      MakeWindow(window_dimensions, window_strides, padding,
                 /*lhs_dilation=*/{}, /*rhs_dilation=*/{},
                 reduce_window->mutable_window());
  if (!window_built) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }
  return RunOpAndParseResponse(&op_request);
}
1351
// Adds a batch-norm-training op on `operand` with the given `scale` and
// `offset`, `epsilon` and `feature_index`.
ComputationDataHandle ComputationBuilder::BatchNormTraining(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormTrainingRequest* bn =
      op_request.mutable_batch_norm_training_request();
  bn->set_epsilon(epsilon);
  bn->set_feature_index(feature_index);
  *bn->mutable_operand() = operand;
  *bn->mutable_scale() = scale;
  *bn->mutable_offset() = offset;
  return RunOpAndParseResponse(&op_request);
}
1365
// Adds a batch-norm-inference op on `operand` using the supplied `scale`,
// `offset`, `mean` and `variance`, plus `epsilon` and `feature_index`.
ComputationDataHandle ComputationBuilder::BatchNormInference(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, const ComputationDataHandle& mean,
    const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormInferenceRequest* bn =
      op_request.mutable_batch_norm_inference_request();
  bn->set_epsilon(epsilon);
  bn->set_feature_index(feature_index);
  *bn->mutable_operand() = operand;
  *bn->mutable_scale() = scale;
  *bn->mutable_offset() = offset;
  *bn->mutable_mean() = mean;
  *bn->mutable_variance() = variance;
  return RunOpAndParseResponse(&op_request);
}
1382
// Adds a batch-norm-gradient op. Note that the `var` argument is stored in
// the request's `variance` field.
ComputationDataHandle ComputationBuilder::BatchNormGrad(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& mean, const ComputationDataHandle& var,
    const ComputationDataHandle& grad_output, float epsilon,
    int64 feature_index) {
  OpRequest op_request;
  BatchNormGradRequest* bn = op_request.mutable_batch_norm_grad_request();
  bn->set_epsilon(epsilon);
  bn->set_feature_index(feature_index);
  *bn->mutable_operand() = operand;
  *bn->mutable_scale() = scale;
  *bn->mutable_mean() = mean;
  *bn->mutable_variance() = var;
  *bn->mutable_grad_output() = grad_output;
  return RunOpAndParseResponse(&op_request);
}
1399
// Adds a cross-replica-sum op over `operand` to the computation being built
// and returns the handle produced by RunOpAndParseResponse.
ComputationDataHandle ComputationBuilder::CrossReplicaSum(
    const ComputationDataHandle& operand) {
  OpRequest req;
  req.mutable_cross_replica_sum_request()->mutable_operand()->CopyFrom(
      operand);
  return RunOpAndParseResponse(&req);
}
1408
// Convenience form of SelectAndScatterWithGeneralPadding that takes a
// `Padding` enum instead of explicit per-dimension padding pairs.
//
// The operand's shape is looked up with GetShape. Its dimensions, together
// with the window dimensions and strides, are passed to MakePadding to
// produce the explicit padding. The call is then delegated to the
// general-padding variant.
//
// Returns an empty ComputationDataHandle when the builder already holds an
// error, or when GetShape fails.
ComputationDataHandle ComputationBuilder::SelectAndScatter(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_or = GetShape(operand);
  // NOTE(review): this path does not call NoteError itself; presumably
  // GetShape records the failure -- confirm against its implementation.
  if (!shape_or.ok()) {
    return ComputationDataHandle();
  }
  const Shape& operand_shape = *shape_or.ValueOrDie();

  return SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides,
      MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
                  window_strides, padding),
      source, init_value, scatter);
}
1429
// Adds a select-and-scatter op with explicitly specified padding.
//
// The operand, source and init_value handles and the handles of the `select`
// and `scatter` computations are copied into a SelectAndScatterRequest. The
// request's window is filled in by MakeWindow from the dimensions, strides
// and padding. If MakeWindow reports failure, an internal error is recorded
// via NoteError and an empty handle is returned. Otherwise the op is issued
// through RunOpAndParseResponse.
ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  OpRequest req;
  SelectAndScatterRequest* const sas = req.mutable_select_and_scatter_request();

  // Data inputs.
  sas->mutable_operand()->CopyFrom(operand);
  sas->mutable_source()->CopyFrom(source);
  sas->mutable_init_value()->CopyFrom(init_value);

  // Sub-computations, referenced by their handles.
  sas->mutable_select()->CopyFrom(select.handle());
  sas->mutable_scatter()->CopyFrom(scatter.handle());

  // The two empty-brace arguments are passed through unchanged to MakeWindow.
  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  sas->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&req);
}
1454
// Adds a reduce-precision op to the computation being built.
//
// The request records the operand handle plus the requested exponent and
// mantissa bit counts. No validation of the bit counts happens here. The op
// is issued through RunOpAndParseResponse and its handle is returned.
ComputationDataHandle ComputationBuilder::ReducePrecision(
    const ComputationDataHandle& operand, const int exponent_bits,
    const int mantissa_bits) {
  OpRequest req;
  ReducePrecisionRequest& rp = *req.mutable_reduce_precision_request();
  rp.mutable_operand()->CopyFrom(operand);
  rp.set_mantissa_bits(mantissa_bits);
  rp.set_exponent_bits(exponent_bits);
  return RunOpAndParseResponse(&req);
}
1466
Send(const ComputationDataHandle & operand,const ChannelHandle & handle)1467 void ComputationBuilder::Send(const ComputationDataHandle& operand,
1468 const ChannelHandle& handle) {
1469 OpRequest op_request;
1470 SendRequest* request = op_request.mutable_send_request();
1471 *request->mutable_operand() = operand;
1472 *request->mutable_channel_handle() = handle;
1473 *op_request.mutable_computation() = computation_.handle();
1474 RunOpAndNoteError(&op_request);
1475 }
1476
Recv(const Shape & shape,const ChannelHandle & handle)1477 ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
1478 const ChannelHandle& handle) {
1479 OpRequest op_request;
1480 RecvRequest* request = op_request.mutable_recv_request();
1481 *request->mutable_shape() = shape;
1482 *request->mutable_channel_handle() = handle;
1483 return RunOpAndParseResponse(&op_request);
1484 }
1485
BuildAndNoteError()1486 Computation ComputationBuilder::BuildAndNoteError() {
1487 DCHECK(parent_builder_ != nullptr);
1488 auto build_status = Build();
1489 if (!build_status.ok()) {
1490 parent_builder_->NoteError(
1491 AddStatus(build_status.status(),
1492 tensorflow::strings::StrCat("error from: ", name_)));
1493 return Computation();
1494 }
1495 return build_status.ConsumeValueOrDie();
1496 }
1497
Build()1498 StatusOr<Computation> ComputationBuilder::Build() {
1499 if (!first_error_.ok()) {
1500 string backtrace;
1501 first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
1502 return AppendStatus(first_error_, backtrace);
1503 }
1504
1505 if (computation_.IsNull()) {
1506 return FailedPrecondition("no computation was built");
1507 }
1508
1509 return {std::move(computation_)};
1510 }
1511
1512 /* static */ ConvolutionDimensionNumbers
CreateDefaultConvDimensionNumbers(int num_spatial_dims)1513 ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
1514 ConvolutionDimensionNumbers dimension_numbers;
1515 dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
1516 dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
1517 dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
1518 dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
1519 dimension_numbers.set_kernel_output_feature_dimension(
1520 kConvKernelOutputDimension);
1521 dimension_numbers.set_kernel_input_feature_dimension(
1522 kConvKernelInputDimension);
1523 for (int i = 0; i < num_spatial_dims; ++i) {
1524 dimension_numbers.add_input_spatial_dimensions(i + 2);
1525 dimension_numbers.add_kernel_spatial_dimensions(i + 2);
1526 dimension_numbers.add_output_spatial_dimensions(i + 2);
1527 }
1528 return dimension_numbers;
1529 }
1530
1531 /* static */ StatusOr<ConvolutionDimensionNumbers>
CreateConvDimensionNumbers(int64 input_batch,int64 input_feature,int64 input_first_spatial,int64 input_second_spatial,int64 output_batch,int64 output_feature,int64 output_first_spatial,int64 output_second_spatial,int64 kernel_output_feature,int64 kernel_input_feature,int64 kernel_first_spatial,int64 kernel_second_spatial)1532 ComputationBuilder::CreateConvDimensionNumbers(
1533 int64 input_batch, int64 input_feature, int64 input_first_spatial,
1534 int64 input_second_spatial, int64 output_batch, int64 output_feature,
1535 int64 output_first_spatial, int64 output_second_spatial,
1536 int64 kernel_output_feature, int64 kernel_input_feature,
1537 int64 kernel_first_spatial, int64 kernel_second_spatial) {
1538 if (std::set<int64>({input_batch, input_feature, input_first_spatial,
1539 input_second_spatial})
1540 .size() != 4) {
1541 return FailedPrecondition(
1542 "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
1543 "%lld)",
1544 input_batch, input_feature, input_first_spatial, input_second_spatial);
1545 }
1546 if (std::set<int64>({kernel_output_feature, kernel_input_feature,
1547 kernel_first_spatial, kernel_second_spatial})
1548 .size() != 4) {
1549 return FailedPrecondition(
1550 "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
1551 "%lld)",
1552 kernel_output_feature, kernel_input_feature, kernel_first_spatial,
1553 kernel_second_spatial);
1554 }
1555 if (std::set<int64>({output_batch, output_feature, output_first_spatial,
1556 output_second_spatial})
1557 .size() != 4) {
1558 return FailedPrecondition(
1559 "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
1560 "%lld)",
1561 output_batch, output_feature, output_first_spatial,
1562 output_second_spatial);
1563 }
1564 ConvolutionDimensionNumbers dimension_numbers;
1565 dimension_numbers.set_input_batch_dimension(input_batch);
1566 dimension_numbers.set_input_feature_dimension(input_feature);
1567 dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
1568 dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
1569 dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
1570 dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
1571 dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
1572 dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
1573 dimension_numbers.set_output_batch_dimension(output_batch);
1574 dimension_numbers.set_output_feature_dimension(output_feature);
1575 dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
1576 dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
1577 return dimension_numbers;
1578 }
1579
1580 } // namespace xla
1581