/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/xla_builder.h"

#include <functional>
#include <numeric>
#include <queue>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

using absl::StrCat;

namespace {

static const char kNameSeparator = '.';

// Retrieves the base name from an instruction or computation's fully
// qualified name, using the separator as the boundary between the base name
// and the numeric identifier.
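// For example, GetBaseName("add.42", '.') returns "add".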
string GetBaseName(const string& name, char separator) {
  auto pos = name.rfind(separator);
  CHECK_NE(pos, string::npos) << name;
  return name.substr(0, pos);
}

// Generates a fully qualified computation/instruction name.
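// For example, GetFullName("add", '.', 42) returns "add.42".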
string GetFullName(const string& base_name, char separator, int64 id) {
  const char separator_str[] = {separator, '\0'};
  return StrCat(base_name, separator_str, id);
}

// Common function to standardize setting name and IDs on computation and
// instruction proto entities.
template <typename T>
void SetProtoIdAndName(T* entry, const string& base_name, char separator,
                       int64 id) {
  entry->set_id(id);
  entry->set_name(GetFullName(base_name, separator, id));
}

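// Converts the element type of the given shape proto to PRED; e.g. a
// hypothetical f32[4,2] shape becomes pred[4,2].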
ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
  return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
}

void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
                              const Shape& shape, bool pred) {
  Literal literal = LiteralUtil::CreateR0(pred);
  Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
  *instr->mutable_shape() = shape.ToProto();
  *instr->mutable_literal() = literal_broadcast.ToProto();
  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
}

// Copies `original_reducer` into a new computation proto with `reducer_id` as
// the new id. If `rewrite_into_pred` is true, the instructions in the reducer
// are rewritten into predicate form.
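// In predicate form, the parameter shapes are converted to PRED and the root
// is rewritten into a kOr over the copied parameters, so the copied
// computation performs a logical reduce-or (hence the "reduce_or" name below).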
HloComputationProto CopyReducer(int64 reducer_id,
                                HloComputationProto* original_reducer,
                                bool rewrite_into_pred, int64* global_id) {
  HloComputationProto reducer;
  SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
  std::vector<int64> operands_id;
  for (auto& inst : original_reducer->instructions()) {
    // Copy params.
    if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
        HloOpcode::kParameter) {
      HloInstructionProto* new_param = reducer.add_instructions();
      *new_param = inst;
      new_param->set_id((*global_id)++);
      *new_param->mutable_name() =
          GetFullName(inst.name(), '.', new_param->id());
      if (rewrite_into_pred) {
        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
      }
      operands_id.push_back(new_param->id());
    }
    if (inst.id() == original_reducer->root_id()) {
      HloInstructionProto* new_root = reducer.add_instructions();
      *new_root = inst;
      new_root->set_id((*global_id)++);
      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
      if (rewrite_into_pred) {
        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
      }
      new_root->clear_operand_ids();
      for (int64 operand_id : operands_id) {
        new_root->add_operand_ids(operand_id);
      }
      reducer.set_root_id(new_root->id());
    }
  }
  return reducer;
}

bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
  HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
  if (opcode == HloOpcode::kCustomCall &&
      instr_proto->custom_call_target() == "SetBound") {
    return true;
  }
  return false;
}

}  // namespace

namespace internal {

XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
                                    absl::Span<const XlaOp> operands,
                                    absl::string_view fusion_kind,
                                    const XlaComputation& fused_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_fusion_kind(std::string(fusion_kind));
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(auto program_shape,
                        fused_computation.GetProgramShape());
    *instr.mutable_shape() = program_shape.result().ToProto();
    builder->AddCalledComputation(fused_computation, &instr);
    return builder->AddInstruction(std::move(instr), HloOpcode::kFusion,
                                   operands);
  });
}

XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
                                     const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
                                   {operand});
  });
}

HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
  return &op.builder()
              ->instructions_[op.builder()->handle_to_index_[op.handle_]];
}

}  // namespace internal

XlaOp operator-(XlaOp x) { return Neg(x); }
XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }

XlaOp operator~(XlaOp x) { return Not(x); }
XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }

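// Note: operator>> below dispatches on signedness, lowering to an arithmetic
// shift for signed element types and a logical shift for unsigned ones.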
XlaOp operator>>(XlaOp x, XlaOp y) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, builder->GetShapePtr(x));
    if (!ShapeUtil::ElementIsIntegral(*shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    if (ShapeUtil::ElementIsSigned(*shape)) {
      return ShiftRightArithmetic(x, y);
    } else {
      return ShiftRightLogical(x, y);
    }
  });
}

StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  auto it = handle_to_index_.find(op.handle());
  if (it == handle_to_index_.end()) {
    return InvalidArgument("No XlaOp with handle %d", op.handle());
  }
  return instruction_shapes_.at(it->second).get();
}

StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
  TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
  return *shape;
}

StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    absl::Span<const XlaOp> operands) const {
  std::vector<Shape> operand_shapes;
  for (XlaOp operand : operands) {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    operand_shapes.push_back(*shape);
  }
  return operand_shapes;
}

XlaBuilder::XlaBuilder(const string& computation_name)
    : name_(computation_name) {}

XlaBuilder::~XlaBuilder() {}

XlaOp XlaBuilder::ReportError(const Status& error) {
  CHECK(!error.ok());
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
  return XlaOp(this);
}

XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
  if (!first_error_.ok()) {
    return XlaOp(this);
  }
  if (!op.ok()) {
    return ReportError(op.status());
  }
  return op.ValueOrDie();
}

XlaOp XlaBuilder::ReportErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  return ReportErrorOrReturn(op_creator());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
                      LookUpInstructionByHandle(root_id));

  ProgramShape program_shape;

  *program_shape.mutable_result() = Shape(root_proto->shape());

  // Check that the parameter numbers are contiguous from 0, and add parameter
  // shapes and names to the program shape.
  const int64 param_count = parameter_numbers_.size();
  for (int64 i = 0; i < param_count; i++) {
    program_shape.add_parameters();
    program_shape.add_parameter_names();
  }
  for (const HloInstructionProto& instr : instructions_) {
    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
    // to verify continuity, we just need to verify that every parameter is in
    // the right range.
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      const int64 index = instr.parameter_number();
      TF_RET_CHECK(index >= 0 && index < param_count)
          << "invalid parameter number: " << index;
      *program_shape.mutable_parameters(index) = Shape(instr.shape());
      *program_shape.mutable_parameter_names(index) = instr.name();
    }
  }
  return program_shape;
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  TF_RET_CHECK(!instructions_.empty());
  return GetProgramShape(instructions_.back().id());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return GetProgramShape(root.handle());
}

void XlaBuilder::IsConstantVisitor(const int64 op_handle,
                                   absl::flat_hash_set<int64>* visited,
                                   bool* is_constant) const {
  if (visited->contains(op_handle) || !*is_constant) {
    return;
  }

  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
  switch (opcode) {
    default:
      for (const int64 operand_id : instr.operand_ids()) {
        IsConstantVisitor(operand_id, visited, is_constant);
      }
      // TODO(b/32495713): We aren't checking the called computations.
      break;

    case HloOpcode::kGetDimensionSize:
      // GetDimensionSize is always considered constant in XLA -- if a dynamic
      // dimension is present, -1 is returned.
      break;
    // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant. We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
    case HloOpcode::kCustomCall:
      if (instr.custom_call_target() == "SetBound") {
        // Set bound is considered constant -- the bound is used as the value.
        break;
      }
      TF_FALLTHROUGH_INTENDED;
    case HloOpcode::kWhile:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kScatter:
      // TODO(b/32495713): We aren't checking the embedded computation in
      // Scatter.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter:
      *is_constant = false;
      break;
  }
  if (!*is_constant) {
    VLOG(1) << "Non-constant: " << instr.name();
  }
  visited->insert(op_handle);
}

Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
                                     ShapeIndex dynamic_size_param_index,
                                     int64 target_param_num,
                                     ShapeIndex target_param_index,
                                     int64 target_dim_num) {
  bool param_exists = false;
  for (size_t index = 0; index < instructions_.size(); ++index) {
    HloInstructionProto& instr = instructions_[index];
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
        instr.parameter_number() == target_param_num) {
      param_exists = true;
      Shape param_shape(instr.shape());
      Shape* param_shape_ptr = &param_shape;
      for (int64 index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      *instr.mutable_shape() = param_shape.ToProto();
      instruction_shapes_[index] =
          absl::make_unique<Shape>(std::move(param_shape));
    }
  }
  if (!param_exists) {
    return InvalidArgument(
        "Asked to mark parameter %lld as a dynamic sized parameter, but it "
        "doesn't exist",
        target_param_num);
  }

  TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
      DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
                                                dynamic_size_param_index},
      DynamicParameterBinding::DynamicDimension{
          target_param_num, target_param_index, target_dim_num}));
  return Status::OK();
}

Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
                                                   std::string attribute,
                                                   std::string value) {
  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
  (*frontend_attributes->mutable_map())[attribute] = std::move(value);
  return Status::OK();
}

XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->ReportError(
        AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
    return {};
  }
  return build_status.ConsumeValueOrDie();
}

Status XlaBuilder::GetCurrentStatus() const {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }
  return Status::OK();
}

StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());
  return Build(instructions_.back().id(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
                                           bool remove_dynamic_dimensions) {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return Build(root.handle(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
                                           bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());

  // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet,
  // so remove all dynamic dimensions before building the XLA program, until
  // the backend supports them.
  if (remove_dynamic_dimensions) {
    std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
      if (shape->tuple_shapes_size() != 0) {
        for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) {
          remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
        }
      }
      for (int64 i = 0; i < shape->dimensions_size(); ++i) {
        shape->set_dynamic_dimension(i, false);
      }
    };
    for (size_t index = 0; index < instructions_.size(); ++index) {
      remove_dynamic_dimension(instruction_shapes_[index].get());
      *instructions_[index].mutable_shape() =
          instruction_shapes_[index]->ToProto();
    }
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
  *entry.mutable_program_shape() = program_shape.ToProto();
  entry.set_root_id(root_id);

  for (auto& instruction : instructions_) {
    // Ensures that the instruction names are unique among the whole graph.
    instruction.set_name(
        GetFullName(instruction.name(), kNameSeparator, instruction.id()));
    entry.add_instructions()->Swap(&instruction);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = entry.program_shape();
  for (auto& e : embedded_) {
    module->add_computations()->Swap(&e.second);
  }
  module->add_computations()->Swap(&entry);
  if (!input_output_aliases_.empty()) {
    TF_RETURN_IF_ERROR(
        PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
  }
  *(module->mutable_dynamic_parameter_binding()) =
      dynamic_parameter_binding_.ToProto();

  // Clear data held by this builder.
  this->instructions_.clear();
  this->instruction_shapes_.clear();
  this->handle_to_index_.clear();
  this->embedded_.clear();
  this->parameter_numbers_.clear();

  return std::move(computation);
}

/* static */ Status XlaBuilder::PopulateInputOutputAlias(
    HloModuleProto* module, const ProgramShape& program_shape,
    const std::vector<InputOutputAlias>& input_output_aliases) {
  HloInputOutputAliasConfig config(program_shape.result());
  for (auto& alias : input_output_aliases) {
    // The HloInputOutputAliasConfig does not do parameter validation, as it
    // only carries the result shape. Maybe it should be constructed with a
    // ProgramShape to allow full validation. We will still get an error when
    // trying to compile the HLO module, but it would be better to have
    // validation at this stage.
    if (alias.param_number >= program_shape.parameters_size()) {
      return InvalidArgument("Invalid parameter number %ld (total %ld)",
                             alias.param_number,
                             program_shape.parameters_size());
    }
    const Shape& parameter_shape = program_shape.parameters(alias.param_number);
    if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
      return InvalidArgument("Invalid parameter %ld index: %s",
                             alias.param_number,
                             alias.param_index.ToString().c_str());
    }
    TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
                                         alias.param_index, alias.kind));
  }
  *module->mutable_input_output_alias() = config.ToProto();
  return Status::OK();
}

StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64 dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }

  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                 XlaOp operand) {
  TF_RETURN_IF_ERROR(first_error_);

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

  CHECK(ShapeUtil::IsScalar(*operand_shape) ||
        operand_shape->rank() == output_shape.rank());
  Shape broadcast_shape =
      ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());

  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(*operand_shape)) {
    return InDimBroadcast(broadcast_shape, operand, {});
  }

  // Do explicit broadcast for degenerate broadcast.
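  // For example (hypothetical shapes), broadcasting f32[2,1,3] into
  // f32[2,4,3] reshapes the operand to f32[2,3] and then broadcasts it with
  // dimensions {0,2}.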
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  for (int i = 0; i < operand_shape->rank(); i++) {
    if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_shape->dimensions(i));
    } else {
      TF_RET_CHECK(operand_shape->dimensions(i) == 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand shape: "
          << *operand_shape << "; output_shape: " << output_shape;
    }
  }

  Shape reshaped_shape =
      ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);

  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);

  for (auto& unmodified : unmodified_dims) {
    if (operand_shape->is_dynamic_dimension(unmodified.first)) {
      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
    }
  }

  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(
      XlaOp reshaped_operand,
      ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
}

XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
    return AddOpWithShape(unop, shape, {operand});
  });
}

XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions,
                           absl::optional<ComparisonDirection> direction,
                           absl::optional<Comparison::Type> type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBinaryOpShape(
                         binop, *lhs_shape, *rhs_shape, broadcast_dimensions));

    const int64 lhs_rank = lhs_shape->rank();
    const int64 rhs_rank = rhs_shape->rank();

    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
      const bool should_broadcast_lhs = lhs_rank < rhs_rank;
      XlaOp from = should_broadcast_lhs ? lhs : rhs;
      const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;

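      // For example (hypothetical shapes), adding f32[3] to f32[2,3] with
      // broadcast_dimensions={1} broadcasts the rank-1 operand to f32[2,3]
      // using dimension 1 before the element-wise add.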
      std::vector<int64> to_size;
      std::vector<bool> to_size_is_dynamic;
      for (int i = 0; i < shape.rank(); i++) {
        to_size.push_back(shape.dimensions(i));
        to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
      }
      for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
        int64 to_dim = broadcast_dimensions[from_dim];
        to_size[to_dim] = from_shape.dimensions(from_dim);
        to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
      }

      const Shape& broadcasted_shape = ShapeUtil::MakeShape(
          from_shape.element_type(), to_size, to_size_is_dynamic);
      TF_ASSIGN_OR_RETURN(
          XlaOp broadcasted_operand,
          InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));

      updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
      updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
    }

    TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
                        GetShapePtr(updated_lhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_lhs,
                          AddBroadcastSequence(shape, updated_lhs));
    }
    TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
                        GetShapePtr(updated_rhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_rhs,
                          AddBroadcastSequence(shape, updated_rhs));
    }

    if (binop == HloOpcode::kCompare) {
      if (!direction.has_value()) {
        return InvalidArgument(
            "kCompare expects a ComparisonDirection, but none provided.");
      }
      if (type == absl::nullopt) {
        return Compare(shape, updated_lhs, updated_rhs, *direction);
      } else {
        return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
      }
    }

    if (direction.has_value()) {
      return InvalidArgument(
          "A comparison direction is provided for a non-compare opcode: %s.",
          HloOpcodeString(binop));
    }
    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
  });
}

XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                      XlaOp lhs, XlaOp rhs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), binop, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction) {
  TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
  return Compare(
      shape, lhs, rhs, direction,
      Comparison::DefaultComparisonType(operand_shape.element_type()));
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction,
                                    Comparison::Type type) {
  HloInstructionProto instr;
  instr.set_comparison_direction(ComparisonDirectionToString(direction));
  instr.set_comparison_type(ComparisonTypeToString(type));
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
}

XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    XlaOp updated_ehs = ehs;
    // The client API supports implicit broadcast for kSelect and kClamp, but
    // XLA does not support implicit broadcast. Make implicit broadcast
    // explicit and update the operands.
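    // For example (a hypothetical case), Select with a scalar pred and f32[2]
    // branches broadcasts the predicate to pred[2] before emitting kSelect.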
    if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
      TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
      TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
      TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));

      absl::optional<Shape> non_scalar_shape;
      for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
        if (shape->IsArray() && shape->rank() != 0) {
          if (non_scalar_shape.has_value()) {
            // TODO(jpienaar): The case where we need to compute the broadcasted
            // shape by considering multiple of the shapes is not implemented.
            // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
            TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
                         shape->dimensions())
                << "Unimplemented implicit broadcast.";
          } else {
            non_scalar_shape = *shape;
          }
        }
      }
      if (non_scalar_shape.has_value()) {
        if (ShapeUtil::IsScalar(*lhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_lhs,
                              AddBroadcastSequence(*non_scalar_shape, lhs));
        }
        if (ShapeUtil::IsScalar(*rhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_rhs,
                              AddBroadcastSequence(*non_scalar_shape, rhs));
        }
        if (ShapeUtil::IsScalar(*ehs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_ehs,
                              AddBroadcastSequence(*non_scalar_shape, ehs));
        }
      }
    }

    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
    TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
    StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
        triop, *lhs_shape, *rhs_shape, *ehs_shape);
    if (!status_or_shape.status().ok()) {
      return InvalidArgument(
          "%s Input scalar shapes may have been changed to non-scalar shapes.",
          status_or_shape.status().error_message());
    }

    return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
                          {updated_lhs, updated_rhs, updated_ehs});
  });
}

XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (literal.shape().IsArray() && literal.element_count() > 1 &&
        literal.IsAllFirst()) {
      Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
      HloInstructionProto instr;
      *instr.mutable_shape() = scalar.shape().ToProto();
      *instr.mutable_literal() = scalar.ToProto();
      TF_ASSIGN_OR_RETURN(
          XlaOp scalar_op,
          AddInstruction(std::move(instr), HloOpcode::kConstant));
      return Broadcast(scalar_op, literal.shape().dimensions());
    } else {
      HloInstructionProto instr;
      *instr.mutable_shape() = literal.shape().ToProto();
      *instr.mutable_literal() = literal.ToProto();
      return AddInstruction(std::move(instr), HloOpcode::kConstant);
    }
  });
}

XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(iota_dimension);
    return AddInstruction(std::move(instr), HloOpcode::kIota);
  });
}

XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
  return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}

XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       absl::Span<const XlaOp> operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
                                         operand_shape_ptrs,
                                         /*to_apply=*/called_program_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
  });
}

XlaOp XlaBuilder::Parameter(
    int64 parameter_number, const Shape& shape, const string& name,
    const std::vector<bool>& replicated_at_leaf_buffers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!parameter_numbers_.insert(parameter_number).second) {
      return InvalidArgument("parameter %d already registered",
                             parameter_number);
    }
    instr.set_parameter_number(parameter_number);
    instr.set_name(name);
    *instr.mutable_shape() = shape.ToProto();
    if (!replicated_at_leaf_buffers.empty()) {
      auto replication = instr.mutable_parameter_replication();
      for (bool replicated : replicated_at_leaf_buffers) {
        replication->add_replicated_at_leaf_buffers(replicated);
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kParameter);
  });
}

XlaOp XlaBuilder::Broadcast(XlaOp operand,
                            absl::Span<const int64> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));

    // The client-level broadcast op just appends dimensions on the left (adds
    // lowest numbered dimensions). The HLO broadcast instruction is more
    // flexible and can add new dimensions anywhere. The instruction's
    // dimensions field maps operand dimensions to dimensions in the broadcast
    // output, so to append dimensions on the left the instruction's dimensions
    // should just be the n highest dimension numbers of the output shape where
    // n is the number of input dimensions.
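    // For example, broadcasting an f32[2,3] operand into an f32[4,2,3] output
    // uses dimensions {1,2}.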
    const int64 operand_rank = operand_shape->rank();
    std::vector<int64> dimensions(operand_rank);
    for (int i = 0; i < operand_rank; ++i) {
      dimensions[i] = i + shape.rank() - operand_rank;
    }
    return InDimBroadcast(shape, operand, dimensions);
  });
}

XlaOp XlaBuilder::BroadcastInDim(
    XlaOp operand, const absl::Span<const int64> out_dim_size,
    const absl::Span<const int64> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Compute the output shape; in the case of a degenerate broadcast,
    // out_dim_size is not necessarily the same as the dimension sizes of the
    // output shape.
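    // For example (hypothetical shapes), broadcasting f32[1,3] with
    // out_dim_size {2,3} and broadcast_dimensions {0,1} first emits an in-dim
    // broadcast to f32[1,3], then expands the degenerate dimension to f32[2,3]
    // via AddBroadcastSequence below.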
    TF_ASSIGN_OR_RETURN(auto output_shape,
                        ShapeUtil::MakeValidatedShape(
                            operand_shape->element_type(), out_dim_size));
    tensorflow::int64 broadcast_rank = broadcast_dimensions.size();
    if (operand_shape->rank() != broadcast_rank) {
      return InvalidArgument(
          "Size of broadcast_dimensions has to match operand's rank; operand "
          "rank: %lld, size of broadcast_dimensions %u.",
          operand_shape->rank(), broadcast_dimensions.size());
    }
    for (int i = 0; i < broadcast_rank; i++) {
      const tensorflow::int64 num_dims = out_dim_size.size();
      if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > num_dims) {
        return InvalidArgument("Broadcast dimension %lld is out of bounds",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(
          broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           *operand_shape, output_shape, broadcast_dimensions)
                           .status());
    std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_rank; i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If broadcast is not degenerate, return broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}

StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                            int64 inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  if (inferred_dimension != -1) {
    instr.add_dimensions(inferred_dimension);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
}

XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
                        absl::Span<const int64> limit_indices,
                        absl::Span<const int64> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
                                         *operand_shape, start_indices,
                                         limit_indices, strides));
    return SliceInternal(shape, operand, start_indices, limit_indices, strides);
  });
}

StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
                                          absl::Span<const int64> start_indices,
                                          absl::Span<const int64> limit_indices,
                                          absl::Span<const int64> strides) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int i = 0, end = start_indices.size(); i < end; i++) {
    auto* slice_config = instr.add_slice_dimensions();
    slice_config->set_start(start_indices[i]);
    slice_config->set_limit(limit_indices[i]);
    slice_config->set_stride(strides[i]);
  }
  return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}

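// For example (hypothetical shapes), SliceInDim(op, /*start_index=*/1,
// /*limit_index=*/3, /*stride=*/1, /*dimno=*/0) on an f32[5,2] operand yields
// the f32[2,2] slice of rows [1, 3).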
XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
                             int64 limit_index, int64 stride, int64 dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> starts(shape->rank(), 0);
    std::vector<int64> limits(shape->dimensions().begin(),
                              shape->dimensions().end());
    std::vector<int64> strides(shape->rank(), 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}

XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            *operand_shape, start_indices_shapes, slice_sizes));
    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64> slice_sizes) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  for (int64 size : slice_sizes) {
    instr.add_dynamic_slice_sizes(size);
  }

  std::vector<XlaOp> operands = {operand};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
}

XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                         *operand_shape, *update_shape, start_indices_shapes));
    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  std::vector<XlaOp> operands = {operand, update};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                        operands);
}

XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                              int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
                                         operand_shape_ptrs, dimension));
    return ConcatInDimInternal(shape, operands, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  instr.add_dimensions(dimension);

  return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}

XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
                      const PaddingConfig& padding_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
                        GetShapePtr(padding_value));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferPadShape(
                         *operand_shape, *padding_value_shape, padding_config));
    return PadInternal(shape, operand, padding_value, padding_config);
  });
}

XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno,
                           int64 pad_lo, int64 pad_hi) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    PaddingConfig padding_config = MakeNoPaddingConfig(shape->rank());
    auto* dims = padding_config.mutable_dimensions(dimno);
    dims->set_edge_padding_low(pad_lo);
    dims->set_edge_padding_high(pad_hi);
    return Pad(operand, padding_value, padding_config);
  });
}

StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
                                        XlaOp padding_value,
                                        const PaddingConfig& padding_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_padding_config() = padding_config;
  return AddInstruction(std::move(instr), HloOpcode::kPad,
                        {operand, padding_value});
}

XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                          absl::Span<const int64> new_sizes,
                          int64 inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
                                               *operand_shape, dimensions,
                                               new_sizes, inferred_dimension));
    XlaOp transposed = IsIdentityPermutation(dimensions)
                           ? operand
                           : Transpose(operand, dimensions);
    return ReshapeInternal(shape, transposed, inferred_dimension);
  });
}

XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                          int64 inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> dimensions(shape->dimensions_size());
    std::iota(dimensions.begin(), dimensions.end(), 0);
    return Reshape(operand, dimensions, new_sizes, inferred_dimension);
  });
}

XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
                          int64 inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    return ReshapeInternal(shape, operand, inferred_dimension);
  });
}

XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
                                 absl::Span<const XlaOp> dim_sizes,
                                 absl::Span<const int64> new_size_bounds,
                                 const std::vector<bool>& dims_are_dynamic) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> dim_size_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
                        GetOperandShapes(dim_sizes));

    absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferDynamicReshapeShape(
                            *operand_shape, dim_size_shape_ptrs,
                            new_size_bounds, dims_are_dynamic));
    TF_RETURN_IF_ERROR(first_error_);
    std::vector<XlaOp> operands;
    operands.reserve(1 + dim_sizes.size());
    operands.push_back(operand);
    for (const XlaOp& dim_size : dim_sizes) {
      operands.push_back(dim_size);
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
                          operands);
  });
}

XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Nothing to collapse; return the operand rather than enqueueing a
      // trivial reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Checks that the collapsed dimensions are in order and consecutive.
    for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
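    // For example, collapsing dimensions {1,2} of an f32[2,3,4] operand
    // yields new sizes [2,12], i.e. an f32[2,12] reshape.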
    TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));

    VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
    VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");

    std::vector<int64> new_sizes;
    for (int i = 0; i < original_shape->rank(); ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape->dimensions(i));
      } else {
        new_sizes.back() *= original_shape->dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";

    return Reshape(operand, new_sizes);
  });
}

void XlaBuilder::Trace(const string& tag, XlaOp operand) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
  });
}

XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
    TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
    TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
    HloOpcode opcode =
        true_shape->IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
    return TernaryOp(opcode, pred, on_true, on_false);
  });
}

XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferVariadicOpShape(
                            HloOpcode::kTuple, operand_shape_ptrs));
    return TupleInternal(shape, elements);
  });
}

StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
                                          absl::Span<const XlaOp> elements) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
}

XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
    if (!tuple_shape->IsTuple()) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(*tuple_shape));
    }
    if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
      return InvalidArgument(
          "GetTupleElement() index (%d) out of range for tuple shape %s", index,
          ShapeUtil::HumanString(*tuple_shape));
    }
    return GetTupleElementInternal(
        ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
        index);
  });
}

StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
                                                    XlaOp tuple_data,
                                                    int64 index) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_tuple_index(index);
  return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
                        {tuple_data});
}

XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
                      const PrecisionConfig* precision_config,
                      absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));

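    // Dot contracts dimension 1 of the lhs (dimension 0 for a rank-1 lhs)
    // with dimension 0 of the rhs; for rank-2 operands this is a standard
    // matrix multiply.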
1291 DotDimensionNumbers dimension_numbers;
1292 dimension_numbers.add_lhs_contracting_dimensions(
1293 lhs_shape->dimensions_size() == 1 ? 0 : 1);
1294 dimension_numbers.add_rhs_contracting_dimensions(0);
1295 return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
1296 });
1297 }

XlaOp XlaBuilder::DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferDotOpShape(
            *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
    return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
                              precision_config);
  });
}

StatusOr<XlaOp> XlaBuilder::DotGeneralInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_dot_dimension_numbers() = dimension_numbers;
  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
}
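
// Illustrative sketch (editorial note, not part of this file): DotGeneral
// takes explicit DotDimensionNumbers, which also covers batched matmul, e.g.
// for lhs f32[8,2,3] and rhs f32[8,3,4]:
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   DotGeneral(x, y, dnums);  // -> f32[8,2,4]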

Status XlaBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) const {
  if (lhs_shape.rank() != rhs_shape.rank()) {
    return InvalidArgument(
        "Convolution arguments must have same number of "
        "dimensions. Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_dims = lhs_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          return InvalidArgument("Expected %d elements for %s, but got %d.",
                                 num_spatial_dims, field_name, numbers.size());
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
                                   field_name, i, numbers.Get(i));
          }
        }
        return Status::OK();
      };
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("input_spatial_dimensions",
                               dimension_numbers.input_spatial_dimensions()));
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("kernel_spatial_dimensions",
                               dimension_numbers.kernel_spatial_dimensions()));
  return check_spatial_dimensions(
      "output_spatial_dimensions",
      dimension_numbers.output_spatial_dimensions());
}
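
// Illustrative sketch (editorial note, not part of this file): for a rank-4
// NCHW input and OIHW kernel, dimension numbers that satisfy the checks
// above would look like:
//   ConvolutionDimensionNumbers dnums;
//   dnums.set_input_batch_dimension(0);            // N
//   dnums.set_input_feature_dimension(1);          // C
//   dnums.add_input_spatial_dimensions(2);         // H
//   dnums.add_input_spatial_dimensions(3);         // W
//   dnums.set_kernel_output_feature_dimension(0);  // O
//   dnums.set_kernel_input_feature_dimension(1);   // I
//   dnums.add_kernel_spatial_dimensions(2);        // H
//   dnums.add_kernel_spatial_dimensions(3);        // W
//   // ...output_* fields analogous to the input_* fields...
// Each spatial list must have exactly rank - 2 entries, all in [0, rank).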

XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> window_strides, Padding padding,
                       int64 feature_group_count, int64 batch_group_count,
                       const PrecisionConfig* precision_config,
                       absl::optional<PrimitiveType> preferred_element_type) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()),
      feature_group_count, batch_group_count, precision_config,
      preferred_element_type);
}

XlaOp XlaBuilder::ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()),
                     feature_group_count, batch_group_count, precision_config,
                     preferred_element_type);
}

XlaOp XlaBuilder::ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));

    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> base_area_dimensions(
        dimension_numbers.input_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
         ++i) {
      base_area_dimensions[i] =
          lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
    }

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    return ConvGeneral(lhs, rhs, window_strides,
                       MakePadding(base_area_dimensions, window_dimensions,
                                   window_strides, padding),
                       dimension_numbers, feature_group_count,
                       batch_group_count, precision_config,
                       preferred_element_type);
  });
}

XlaOp XlaBuilder::ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers, feature_group_count,
                            batch_group_count, precision_config,
                            preferred_element_type);
}

XlaOp XlaBuilder::ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    TF_ASSIGN_OR_RETURN(Window window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding,
                            lhs_dilation, rhs_dilation));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferConvolveShape(
            *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
            window, dimension_numbers, preferred_element_type));
    return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
                                      padding, lhs_dilation, rhs_dilation,
                                      dimension_numbers, feature_group_count,
                                      batch_group_count, precision_config);
  });
}
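
// Illustrative sketch (editorial note, not part of this file): the Conv
// entry points above are progressively thinner wrappers over
// ConvGeneralDilated. Passing a kernel-side dilation of {2, 2} yields an
// atrous ("dilated") convolution; `input`, `kernel`, and `dnums` here are
// assumed to be built as in the VerifyConvolution note above:
//   ConvGeneralDilated(input, kernel,
//                      /*window_strides=*/{1, 1},
//                      /*padding=*/{{0, 0}, {0, 0}},
//                      /*lhs_dilation=*/{1, 1},
//                      /*rhs_dilation=*/{2, 2}, dnums,
//                      /*feature_group_count=*/1, /*batch_group_count=*/1);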

StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
  TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
                                         window_dimensions, window_strides,
                                         padding, lhs_dilation, rhs_dilation));
  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferConvolveShape(
          *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);
  instr.set_padding_type(padding_type);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return std::move(instr);
}

XlaOp XlaBuilder::DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionInputGrad");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {input_sizes, lhs, rhs});
  });
}

XlaOp XlaBuilder::DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(activations, gradients, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers,
                               feature_group_count, batch_group_count,
                               precision_config, padding_type,
                               preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionKernelGrad");
    // The kernel gradient has the kernel's shape and therefore shouldn't have
    // any dynamic dimensions.
    instr.mutable_shape()->clear_is_dynamic_dimension();
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {activations, gradients});
  });
}

XlaOp XlaBuilder::DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));
    instr.set_custom_call_target("DynamicConvolutionForward");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
  });
}
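
// Editorial note: the three DynamicConv* builders above deliberately emit
// kCustomCall rather than kConvolution. The window and shape metadata is
// computed exactly as for a static convolution; the named custom-call
// targets ("DynamicConvolutionForward" etc.) are later rewritten by XLA's
// dynamic padder, which materializes the dynamic padding sizes as HLO (see
// the analogous handling in ReduceWindow below).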

StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }

  return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
}

XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                      const absl::Span<const int64> fft_length) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
                                         *operand_shape, fft_type, fft_length));
    return FftInternal(shape, operand, fft_type, fft_length);
  });
}

StatusOr<XlaOp> XlaBuilder::FftInternal(
    const Shape& shape, XlaOp operand, const FftType fft_type,
    const absl::Span<const int64> fft_length) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_fft_type(fft_type);
  for (int64 i : fft_length) {
    instr.add_fft_length(i);
  }

  return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
}

StatusOr<XlaOp> XlaBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  HloInstructionProto instr;
  *instr.mutable_triangular_solve_options() = std::move(options);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b});
}

StatusOr<XlaOp> XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
                                             bool lower) {
  HloInstructionProto instr;
  xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
  options.set_lower(lower);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
}

XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
    instr.set_infeed_config(config);

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    // Infeed takes a single token operand. Generate the token to pass to the
    // infeed.
    XlaOp token;
    auto make_token = [&]() {
      HloInstructionProto token_instr;
      *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
      return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
    };
    if (sharding()) {
      // Arbitrarily assign the token to device 0.
      OpSharding sharding = sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this, sharding);
      TF_ASSIGN_OR_RETURN(token, make_token());
    } else {
      TF_ASSIGN_OR_RETURN(token, make_token());
    }

    // The sharding is set by the client according to the data tuple shape.
    // However, the shape of the infeed instruction is a tuple containing the
    // data and a token. For tuple sharding type, the sharding must be changed
    // to accommodate the token.
    XlaOp infeed;
    if (sharding() && sharding()->type() == OpSharding::TUPLE) {
      // TODO(b/80000000): Remove this when clients have been updated to handle
      // tokens.
      OpSharding infeed_instruction_sharding = *sharding();
      // Arbitrarily assign the token to device 0.
      *infeed_instruction_sharding.add_tuple_shardings() =
          sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this,
                                                  infeed_instruction_sharding);
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    } else {
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    }

    // The infeed instruction produces a tuple of the infeed data and a token.
    // Return the XLA op containing just the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto infeed_data;
    *infeed_data.mutable_shape() = shape.ToProto();
    infeed_data.set_tuple_index(0);
    return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
                          {infeed});
  });
}
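
// Illustrative sketch (editorial note, not part of this file): the legacy
// Infeed entry point hides all of the token plumbing, so a client simply
// writes
//   XlaBuilder b("infeed_example");
//   auto data = Infeed(&b, ShapeUtil::MakeShapeWithLayout(F32, {16}, {0}));
// and receives an op of the data shape; the (data, token) tuple and the
// GetTupleElement that unwraps it are generated internally above.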

XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
                                  const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }
    return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
  });
}

StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = infeed_instruction_shape.ToProto();
  instr.set_infeed_config(config);
  return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
}

void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
                         const string& outfeed_config) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();

    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();

    instr.set_outfeed_config(outfeed_config);

    // Outfeed takes a token as its second operand. Generate the token to pass
    // to the outfeed.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    TF_RETURN_IF_ERROR(
        AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
            .status());

    // The outfeed instruction produces a token. However, existing users expect
    // a nil shape (empty tuple). This should only be relevant if the outfeed
    // is the root of a computation.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto tuple_instr;
    *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();

    // The dummy tuple should have no sharding.
    {
      XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
      TF_ASSIGN_OR_RETURN(
          XlaOp empty_tuple,
          AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
      return empty_tuple;
    }
  });
}

XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
                                   const Shape& shape_with_layout,
                                   const string& outfeed_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    return OutfeedWithTokenInternal(operand, token, shape_with_layout,
                                    outfeed_config);
  });
}

StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
    const string& outfeed_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
  *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
  instr.set_outfeed_config(outfeed_config);
  return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
                        {operand, token});
}

XlaOp XlaBuilder::CreateToken() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
  });
}

XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (tokens.empty()) {
      return InvalidArgument("AfterAll requires at least one operand");
    }
    for (int i = 0, end = tokens.size(); i < end; ++i) {
      XlaOp operand = tokens[i];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      if (!operand_shape->IsToken()) {
        return InvalidArgument(
            "All operands to AfterAll must be tokens; operand %d has shape %s",
            i, ShapeUtil::HumanString(*operand_shape));
      }
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
  });
}
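
// Illustrative sketch (editorial note, not part of this file): the token-
// based variants let a client order side-effecting ops explicitly;
// `infeed_shape` and `outfeed_shape` are assumed to carry layouts:
//   XlaBuilder b("token_example");
//   auto t0 = CreateToken(&b);
//   auto in = InfeedWithToken(t0, infeed_shape);  // (data, token) tuple
//   auto data = GetTupleElement(in, 0);
//   auto t1 = GetTupleElement(in, 1);
//   auto t2 = OutfeedWithToken(data, t1, outfeed_shape, "");
//   AfterAll(&b, {t1, t2});  // joins tokens; non-token operands would fail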

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      int64 operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        ++operand_num;
      }
    }
    return CustomCallInternal(call_target_name, operands, shape, opaque,
                              operand_shapes_with_layout, has_side_effect,
                              output_operand_aliasing, literal);
  });
}

StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_custom_call_target(call_target_name);
  instr.set_backend_config(opaque);
  if (operand_shapes_with_layout.has_value()) {
    instr.set_constrain_layout(true);
    for (const Shape& operand_shape : *operand_shapes_with_layout) {
      *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
    }
  }
  if (literal != nullptr) {
    *instr.mutable_literal() = literal->ToProto();
  }
  instr.set_custom_call_has_side_effect(has_side_effect);
  for (const auto& pair : output_operand_aliasing) {
    auto aliasing = instr.add_custom_call_output_operand_aliasing();
    aliasing->set_operand_index(pair.second.first);
    for (int64 index : pair.second.second) {
      aliasing->add_operand_shape_index(index);
    }
    for (int64 index : pair.first) {
      aliasing->add_output_shape_index(index);
    }
  }
  return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}
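
// Editorial note: each entry in output_operand_aliasing pairs an output
// ShapeIndex with an (operand number, operand ShapeIndex), declaring that the
// custom call writes that output slot in place over the given operand buffer.
// For example, {{ShapeIndex{}, {0, ShapeIndex{}}}} would alias the whole
// result with operand 0 (an in-place custom call), assuming matching shapes.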

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const XlaComputation& computation, const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.set_custom_call_target(call_target_name);
    instr.set_backend_config(opaque);
    if (literal != nullptr) {
      *instr.mutable_literal() = literal->ToProto();
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      instr.set_constrain_layout(true);
      int64 operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
        ++operand_num;
      }
    }
    AddCalledComputation(computation, &instr);
    for (const auto& pair : output_operand_aliasing) {
      auto aliasing = instr.add_custom_call_output_operand_aliasing();
      aliasing->set_operand_index(pair.second.first);
      for (int64 index : pair.second.second) {
        aliasing->add_operand_shape_index(index);
      }
      for (int64 index : pair.first) {
        aliasing->add_output_shape_index(index);
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
  });
}

XlaOp XlaBuilder::Transpose(XlaOp operand,
                            absl::Span<const int64> permutation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
                                         *operand_shape, permutation));
    return TransposeInternal(shape, operand, permutation);
  });
}

StatusOr<XlaOp> XlaBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64 dim : permutation) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
}

XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
                                         *operand_shape, dimensions));
    return RevInternal(shape, operand, dimensions);
  });
}

StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> dimensions) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64 dim : dimensions) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
}

XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       const XlaComputation& comparator, int64 dimension,
                       bool is_stable) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                        GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
                                         HloOpcode::kSort, operand_shape_ptrs));
    return SortInternal(shape, operands, comparator, dimension, is_stable);
  });
}

StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
                                         absl::Span<const XlaOp> operands,
                                         const XlaComputation& comparator,
                                         int64 dimension, bool is_stable) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_is_stable(is_stable);
  if (dimension == -1) {
    TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
    dimension = keys_shape->rank() - 1;
  }
  instr.add_dimensions(dimension);
  AddCalledComputation(comparator, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
}
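
// Illustrative sketch (editorial note, not part of this file): Sort takes a
// scalar comparator computation, and dimension == -1 resolves to the last
// dimension of the first operand, as SortInternal shows above. Assuming
// `keys` is an op on builder `b`:
//   XlaBuilder cmp_b("lt");
//   auto p0 = Parameter(&cmp_b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
//   auto p1 = Parameter(&cmp_b, 1, ShapeUtil::MakeShape(F32, {}), "p1");
//   Lt(p0, p1);
//   XlaComputation lt = cmp_b.Build().ValueOrDie();
//   Sort({keys}, lt, /*dimension=*/-1, /*is_stable=*/true);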

XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
  });
}

XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return BitcastConvertTypeInternal(shape, operand);
  });
}

StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
                                                       XlaOp operand) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
                        {operand});
}

XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
  return TernaryOp(HloOpcode::kClamp, min, operand, max);
}

XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions,
                      absl::Span<const XlaOp> static_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!static_operands.empty()) {
      return Unimplemented("static_operands is not supported in Map");
    }

    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferMapShape(
                         operand_shape_ptrs, called_program_shape, dimensions));
    *instr.mutable_shape() = shape.ToProto();

    Shape output_shape(instr.shape());
    const int64 output_rank = output_shape.rank();
    AddCalledComputation(computation, &instr);
    std::vector<XlaOp> new_operands(operands.begin(), operands.end());
    for (XlaOp& new_operand : new_operands) {
      TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
      const int64 rank = shape->rank();
      if (rank != output_rank) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            InDimBroadcast(output_shape, new_operand, {}));
        TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
      }
      if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            AddBroadcastSequence(output_shape, new_operand));
      }
    }

    return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
  });
}

XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
                        absl::Span<const XlaOp> parameters,
                        const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check the number of parameters per RNG distribution.
    switch (distribution) {
      case RandomDistribution::RNG_NORMAL:
      case RandomDistribution::RNG_UNIFORM:
        if (parameters.size() != 2) {
          return InvalidArgument(
              "RNG distribution (%s) expects 2 parameters, but got %ld",
              RandomDistribution_Name(distribution), parameters.size());
        }
        break;
      default:
        LOG(FATAL) << "unhandled distribution " << distribution;
    }

    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    return RngOpInternal(distribution, parameters, shape);
  });
}

StatusOr<XlaOp> XlaBuilder::RngOpInternal(RandomDistribution distribution,
                                          absl::Span<const XlaOp> parameters,
                                          const Shape& shape) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_distribution(distribution);

  return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
}

XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}
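
// Illustrative sketch (editorial note, not part of this file): both RNG
// wrappers take two scalar parameters of the result element type:
//   XlaBuilder b("rng_example");
//   auto lo = ConstantR0<float>(&b, 0.0f);
//   auto hi = ConstantR0<float>(&b, 1.0f);
//   RngUniform(lo, hi, ShapeUtil::MakeShape(F32, {128}));  // U[lo, hi)
//   RngNormal(lo, hi, ShapeUtil::MakeShape(F32, {128}));   // mu=lo, sigma=hi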

XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
                                  XlaOp initial_state, const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
    Shape output_shape = shape;
    switch (output_shape.element_type()) {
      case PrimitiveType::F32:
      case PrimitiveType::S32:
      case PrimitiveType::U32:
        output_shape.set_element_type(PrimitiveType::U32);
        break;
      case PrimitiveType::F64:
      case PrimitiveType::S64:
      case PrimitiveType::U64:
        output_shape.set_element_type(PrimitiveType::U64);
        break;
      default:
        return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
                               PrimitiveType_Name(output_shape.element_type()));
    }
    return RngBitGeneratorInternal(
        ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm,
        initial_state);
  });
}

StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  HloInstructionProto instr;
  *instr.mutable_shape() = full_result_shape.ToProto();
  instr.set_rng_algorithm(algorithm);
  return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
                        {initial_state});
}

XlaOp XlaBuilder::While(const XlaComputation& condition,
                        const XlaComputation& body, XlaOp init) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Infer shape.
    TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
                        condition.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
                                         condition_program_shape,
                                         body_program_shape, *init_shape));
    return WhileInternal(shape, condition, body, init);
  });
}

StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
                                          const XlaComputation& condition,
                                          const XlaComputation& body,
                                          XlaOp init) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  // Body comes before condition computation in the vector.
  AddCalledComputation(body, &instr);
  AddCalledComputation(condition, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
}
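
// Illustrative sketch (editorial note, not part of this file): a While loop
// threads a single loop-carried value; the condition maps it to a PRED
// scalar, the body maps it to a value of the same shape. A counter to 10:
//   XlaBuilder b("while_example"), cond_b("cond"), body_b("body");
//   auto c = Parameter(&cond_b, 0, ShapeUtil::MakeShape(S32, {}), "v");
//   Lt(c, ConstantR0<int32>(&cond_b, 10));
//   auto v = Parameter(&body_b, 0, ShapeUtil::MakeShape(S32, {}), "v");
//   Add(v, ConstantR0<int32>(&body_b, 1));
//   While(cond_b.Build().ValueOrDie(), body_b.Build().ValueOrDie(),
//         ConstantR0<int32>(&b, 0));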

XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
                         const GatherDimensionNumbers& dimension_numbers,
                         absl::Span<const int64> slice_sizes,
                         bool indices_are_sorted) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
                        GetShapePtr(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
                                         *input_shape, *start_indices_shape,
                                         dimension_numbers, slice_sizes));
    return GatherInternal(shape, input, start_indices, dimension_numbers,
                          slice_sizes, indices_are_sorted);
  });
}

StatusOr<XlaOp> XlaBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  HloInstructionProto instr;
  instr.set_indices_are_sorted(indices_are_sorted);
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_gather_dimension_numbers() = dimension_numbers;
  for (int64 bound : slice_sizes) {
    instr.add_gather_slice_sizes(bound);
  }

  return AddInstruction(std::move(instr), HloOpcode::kGather,
                        {input, start_indices});
}
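
// Illustrative sketch (editorial note, not part of this file): gathering rows
// of a f32[N, D] `matrix` by an s32[K] `indices` vector, where D stands in
// for the concrete row width:
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 holds the row data
//   dnums.add_collapsed_slice_dims(0);  // the size-1 row dim is collapsed
//   dnums.add_start_index_map(0);       // indices address input dim 0
//   dnums.set_index_vector_dim(1);      // each index is a scalar
//   Gather(matrix, indices, dnums, /*slice_sizes=*/{1, D});  // -> f32[K, D]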

XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers,
                          bool indices_are_sorted, bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
                        GetShapePtr(scatter_indices));
    TF_ASSIGN_OR_RETURN(const Shape* updates_shape, GetShapePtr(updates));
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        update_computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferScatterShape(
                         *input_shape, *scatter_indices_shape, *updates_shape,
                         to_apply_shape, dimension_numbers));
    return ScatterInternal(shape, input, scatter_indices, updates,
                           update_computation, dimension_numbers,
                           indices_are_sorted, unique_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::ScatterInternal(
    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
    const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_indices_are_sorted(indices_are_sorted);
    instr.set_unique_indices(unique_indices);
    *instr.mutable_shape() = shape.ToProto();
    *instr.mutable_scatter_dimension_numbers() = dimension_numbers;

    AddCalledComputation(update_computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kScatter,
                          {input, scatter_indices, updates});
  });
}

XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
                              const XlaComputation& true_computation,
                              XlaOp false_operand,
                              const XlaComputation& false_computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(predicate));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
      return InvalidArgument(
          "Argument to predicated-Conditional is not a scalar of PRED type "
          "(%s).",
          ShapeUtil::HumanString(*shape));
    }
    // The index of true_computation must be 0 and that of false_computation
    // must be 1.
    return ConditionalImpl(predicate, {&true_computation, &false_computation},
                           {true_operand, false_operand});
  });
}

XlaOp XlaBuilder::Conditional(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
      return InvalidArgument(
          "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    return ConditionalImpl(branch_index, branch_computations, branch_operands);
  });
}

XlaOp XlaBuilder::ConditionalImpl(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
                        GetShapePtr(branch_index));
    std::vector<Shape> branch_operand_shapes(branch_operands.size());
    std::vector<ProgramShape> branch_computation_shapes(
        branch_computations.size());
    for (int j = 0, end = branch_operands.size(); j < end; ++j) {
      TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
                          GetShape(branch_operands[j]));
      TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
                          branch_computations[j]->GetProgramShape());
    }
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferConditionalShape(
                            *branch_index_shape, branch_computation_shapes,
                            branch_operand_shapes));
    *instr.mutable_shape() = shape.ToProto();

    for (const XlaComputation* branch_computation : branch_computations) {
      AddCalledComputation(*branch_computation, &instr);
    }

    std::vector<XlaOp> operands(1, branch_index);
    for (const XlaOp branch_operand : branch_operands) {
      operands.emplace_back(branch_operand);
    }
    return AddInstruction(std::move(instr), HloOpcode::kConditional,
                          absl::MakeSpan(operands));
  });
}
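
// Editorial note: both public Conditional overloads funnel into
// ConditionalImpl, which packs the PRED predicate or S32 branch index as
// operand 0 followed by one operand per branch. The predicated form maps the
// true branch to computation index 0 and the false branch to index 1, as the
// comment above notes.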
2413
CheckOpBuilder(XlaOp op) const2414 Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
2415 if (this != op.builder()) {
2416 return InvalidArgument(
2417 "XlaOp with handle %d is built by builder '%s', but is trying to use "
2418 "it in builder '%s'",
2419 op.handle(), op.builder()->name(), name());
2420 }
2421 return Status::OK();
2422 }
2423
Reduce(XlaOp operand,XlaOp init_value,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)2424 XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
2425 const XlaComputation& computation,
2426 absl::Span<const int64> dimensions_to_reduce) {
2427 return Reduce(absl::Span<const XlaOp>({operand}),
2428 absl::Span<const XlaOp>({init_value}), computation,
2429 dimensions_to_reduce);
2430 }
2431
Reduce(absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)2432 XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
2433 absl::Span<const XlaOp> init_values,
2434 const XlaComputation& computation,
2435 absl::Span<const int64> dimensions_to_reduce) {
2436 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2437 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
2438 computation.GetProgramShape());
2439
2440 std::vector<XlaOp> all_operands;
2441 all_operands.insert(all_operands.end(), operands.begin(), operands.end());
2442 all_operands.insert(all_operands.end(), init_values.begin(),
2443 init_values.end());
2444
2445 std::vector<const Shape*> operand_shape_ptrs;
2446 TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
2447 GetOperandShapes(all_operands));
2448 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2449 [](const Shape& shape) { return &shape; });
2450
2451 TF_ASSIGN_OR_RETURN(
2452 Shape shape,
2453 ShapeInference::InferReduceShape(
2454 operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
2455 return ReduceInternal(shape, all_operands, computation,
2456 dimensions_to_reduce);
2457 });
2458 }
2459
StatusOr<XlaOp> XlaBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64> dimensions_to_reduce) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();

    for (int64 dim : dimensions_to_reduce) {
      instr.add_dimensions(dim);
    }

    AddCalledComputation(computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
  });
}

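// Reduces the operand over all of its dimensions, yielding a scalar result.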
XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<int64> all_dimnos(operand_shape->rank());
    std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
    return Reduce(operand, init_value, computation, all_dimnos);
  });
}

XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReduceWindow(absl::MakeSpan(&operand, 1),
                      absl::MakeSpan(&init_value, 1), computation,
                      window_dimensions, window_strides, padding);
}

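// Variadic ReduceWindow. When SAME padding is requested over a dynamic
// dimension the padding amounts cannot be fixed at build time, so the op is
// emitted as a "DynamicReduceWindowSamePadding" CustomCall for the dynamic
// padder to rewrite; otherwise it lowers to a regular kReduceWindow via
// ReduceWindowWithGeneralPadding.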
XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
                               absl::Span<const XlaOp> init_values,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const Shape* operand_shape = nullptr;
    for (const auto& operand : operands) {
      TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
      TF_RETURN_IF_ERROR(
          ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
                                window_dimensions, window_strides));
    }
    CHECK(operand_shape != nullptr);
    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);
    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions,
        // so we create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          ReduceWindowInternal(operands, init_values, computation,
                               window_dimensions, window_strides, {}, {},
                               padding_values));
      instr.set_custom_call_target("DynamicReduceWindowSamePadding");
      std::vector<XlaOp> args;
      args.insert(args.end(), operands.begin(), operands.end());
      args.insert(args.end(), init_values.begin(), init_values.end());
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall, args);
    }
    return ReduceWindowWithGeneralPadding(
        operands, init_values, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}

XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (operands.size() == 1) {
      const auto& operand = operands[0];
      const auto& init_value = init_values[0];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      operand_shapes.push_back(operand_shape);
      TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
      init_shapes.push_back(init_shape);

      TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                          computation.GetProgramShape());
      TF_ASSIGN_OR_RETURN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              window_dimensions, window_strides, padding,
                              /*lhs_dilation=*/base_dilations,
                              /*rhs_dilation=*/window_dilations));
      TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
                                           absl::MakeSpan(operand_shapes),
                                           absl::MakeSpan(init_shapes), window,
                                           to_apply_shape));
      return ReduceWindowInternal(shape, operands[0], init_values[0],
                                  computation, window);
    }

    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        ReduceWindowInternal(operands, init_values, computation,
                             window_dimensions, window_strides, base_dilations,
                             window_dilations, padding));
    std::vector<XlaOp> args;
    args.insert(args.end(), operands.begin(), operands.end());
    args.insert(args.end(), init_values.begin(), init_values.end());
    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, args);
  });
}

StatusOr<HloInstructionProto> XlaBuilder::ReduceWindowInternal(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  for (int i = 0; i < operands.size(); ++i) {
    const auto& operand = operands[i];
    const auto& init_value = init_values[i];
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    operand_shapes.push_back(operand_shape);
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
    init_shapes.push_back(init_shape);
  }
  TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                      computation.GetProgramShape());
  TF_ASSIGN_OR_RETURN(auto window,
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/base_dilations,
                          /*rhs_dilation=*/window_dilations));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferReduceWindowShape(
                          absl::MakeSpan(operand_shapes),
                          absl::MakeSpan(init_shapes), window, to_apply_shape));
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);
  AddCalledComputation(computation, &instr);
  return instr;
}

StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);

  AddCalledComputation(computation, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
                        {operand, init_value});
}

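// The three BatchNorm entry points below (training, inference, grad) all
// follow the same pattern: infer the result shape from the operand shapes,
// then emit the corresponding HLO with epsilon and feature_index attached.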
XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                                    float epsilon, int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferBatchNormTrainingShape(
            *operand_shape, *scale_shape, *offset_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
                          {operand, scale, offset});
  });
}

XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                                     XlaOp mean, XlaOp variance, float epsilon,
                                     int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
    TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferBatchNormInferenceShape(
                            *operand_shape, *scale_shape, *offset_shape,
                            *mean_shape, *variance_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
                          {operand, scale, offset, mean, variance});
  });
}

XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                                XlaOp batch_var, XlaOp grad_output,
                                float epsilon, int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape, GetShapePtr(batch_mean));
    TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
    TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
                        GetShapePtr(grad_output));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBatchNormGradShape(
                         *operand_shape, *scale_shape, *batch_mean_shape,
                         *batch_var_shape, *grad_output_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
                          {operand, scale, batch_mean, batch_var, grad_output});
  });
}

XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
                            int64 shard_count,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Layout>& layout,
                            const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferAllGatherShape(
                            *operand_shape, all_gather_dimension, shard_count));
    if (layout) {
      *inferred_shape.mutable_layout() = *layout;
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = inferred_shape.ToProto();

    instr.add_dimensions(all_gather_dimension);
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      instr.set_use_global_device_ids(use_global_device_ids.value());
    }

    TF_ASSIGN_OR_RETURN(
        auto all_gather,
        AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand}));
    return all_gather;
  });
}

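// CrossReplicaSum is AllReduce with a canned reducer: a scalar computation
// that ORs PRED elements and adds all other element types.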
XlaOp XlaBuilder::CrossReplicaSum(
    XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    const Shape* element_shape;
    if (shape->IsTuple()) {
      if (shape->tuple_shapes_size() == 0) {
        return Unimplemented(
            "0 element tuple CrossReplicaSum is not supported");
      }
      element_shape = &shape->tuple_shapes(0);
    } else {
      element_shape = shape;
    }
    const Shape scalar_shape =
        ShapeUtil::MakeShape(element_shape->element_type(), {});
    auto b = CreateSubBuilder("sum");
    auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
    auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
    if (scalar_shape.element_type() == PRED) {
      Or(x, y);
    } else {
      Add(x, y);
    }
    TF_ASSIGN_OR_RETURN(auto computation, b->Build());
    return AllReduce(operand, computation, replica_groups,
                     /*channel_id=*/absl::nullopt);
  });
}

XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Shape>& shape_with_layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple AllReduce is not supported");
      }
      for (int64 i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of AllReduce must have the same "
              "element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferAllReduceShape(operand_shapes));
    if (shape_with_layout) {
      if (!LayoutUtil::HasLayout(*shape_with_layout)) {
        return InvalidArgument("shape_with_layout must have the layout set: %s",
                               shape_with_layout->ToString());
      }
      if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
        return InvalidArgument(
            "Provided shape_with_layout must be compatible with the "
            "operand shape: %s vs %s",
            shape_with_layout->ToString(), operand_shape->ToString());
      }
      instr.set_constrain_layout(true);
      if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
        // For a single-element tuple, take the tuple element shape.
        TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
        *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
      } else {
        *instr.mutable_shape() = shape_with_layout->ToProto();
      }
    } else {
      *instr.mutable_shape() = inferred_shape.ToProto();
    }

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }

    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }

    AddCalledComputation(computation, &instr);

    TF_ASSIGN_OR_RETURN(
        auto all_reduce,
        AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
    if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
      // For a single-element tuple, wrap the result into a tuple.
      TF_RET_CHECK(operand_shapes.size() == 1);
      TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
      return Tuple({all_reduce});
    }
    return all_reduce;
  });
}

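// AllToAll splits the operand into split_count blocks along split_dimension,
// exchanges the blocks among the replicas in each replica group, and
// concatenates the received blocks along concat_dimension.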
XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                           int64 concat_dimension, int64 split_count,
                           const std::vector<ReplicaGroup>& replica_groups,
                           const absl::optional<Layout>& layout) {
  // An array AllToAll may need to violate the layout constraint to be legal,
  // so use the tuple version when a layout is provided.
  if (layout.has_value()) {
    return AllToAllTuple(operand, split_dimension, concat_dimension,
                         split_count, replica_groups, layout);
  }
  return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
                       replica_groups);
}

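// Array AllToAll: emits a single kAllToAll whose shape matches the operand,
// then recovers the requested split/concat semantics with a
// reshape-transpose-reshape sequence. For example, an f32[8,4] operand with
// split_dimension=0, concat_dimension=1 and split_count=4 is reshaped to
// f32[4,2,4], transposed with permutation {1,0,2} to f32[2,4,4], and finally
// reshaped to the inferred result shape f32[2,16].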
XlaOp XlaBuilder::AllToAllArray(
    XlaOp operand, int64 split_dimension, int64 concat_dimension,
    int64 split_count, const std::vector<ReplicaGroup>& replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape all_to_all_shape,
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count));
    HloInstructionProto instr;
    *instr.mutable_shape() = operand_shape->ToProto();
    if (replica_groups.empty()) {
      auto* group = instr.add_replica_groups();
      for (int64 i = 0; i < split_count; ++i) {
        group->add_replica_ids(i);
      }
    } else {
      for (const ReplicaGroup& group : replica_groups) {
        *instr.add_replica_groups() = group;
      }
    }
    instr.add_dimensions(split_dimension);
    TF_ASSIGN_OR_RETURN(
        XlaOp all_to_all,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
    if (split_dimension == concat_dimension) {
      return all_to_all;
    }
    DimensionVector sizes;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      if (i != split_dimension) {
        sizes.push_back(operand_shape->dimensions(i));
        continue;
      }
      sizes.push_back(split_count);
      sizes.push_back(operand_shape->dimensions(i) / split_count);
    }
    all_to_all = Reshape(all_to_all, sizes);

    std::vector<int64> permutation;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      int64 dim_after_reshape = i >= split_dimension ? i + 1 : i;
      if (i == concat_dimension) {
        permutation.push_back(split_dimension);
      }
      permutation.push_back(dim_after_reshape);
    }
    all_to_all = Transpose(all_to_all, permutation);
    return Reshape(all_to_all_shape, all_to_all);
  });
}

XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64 split_dimension,
                                int64 concat_dimension, int64 split_count,
                                const std::vector<ReplicaGroup>& replica_groups,
                                const absl::optional<Layout>& layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    // The HloInstruction for AllToAll currently only handles the data
    // communication: it accepts N already-split parts and scatters them to N
    // cores, and each core gathers the N received parts into a tuple as the
    // output. So here we explicitly split the operand before the HLO
    // AllToAll and concatenate the received tuple elements afterwards.
    //
    // First, run shape inference to make sure the shapes are valid.
    TF_RETURN_IF_ERROR(
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count)
            .status());

    // Split into N parts.
    std::vector<XlaOp> slices;
    slices.reserve(split_count);
    const int64 block_size =
        operand_shape->dimensions(split_dimension) / split_count;
    for (int i = 0; i < split_count; i++) {
      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
                                  /*limit_index=*/(i + 1) * block_size,
                                  /*stride=*/1, /*dimno=*/split_dimension));
    }

    // Handle data communication.
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
    std::vector<const Shape*> slice_shape_ptrs;
    absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));

    if (layout) {
      TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
      for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
        const int64 layout_minor_to_major_size =
            layout->minor_to_major().size();
        if (layout_minor_to_major_size != shape.tuple_shapes(i).rank()) {
          return InvalidArgument(
              "Provided layout must be compatible with the operand shape: %s "
              "vs %s",
              layout->ToString(), operand_shape->ToString());
        }
        *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
      }
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = shape.ToProto();

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    TF_ASSIGN_OR_RETURN(
        XlaOp alltoall,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));

    // Concat the N received parts.
    std::vector<XlaOp> received;
    received.reserve(split_count);
    for (int i = 0; i < split_count; i++) {
      received.push_back(this->GetTupleElement(alltoall, i));
    }
    return this->ConcatInDim(received, concat_dimension);
  });
}

XlaOp XlaBuilder::CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferCollectivePermuteShape(*operand_shape));
    *instr.mutable_shape() = shape.ToProto();

    for (const auto& pair : source_target_pairs) {
      auto* proto_pair = instr.add_source_target_pairs();
      proto_pair->set_source(pair.first);
      proto_pair->set_target(pair.second);
    }

    return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
                          {operand});
  });
}

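// Emits a kReplicaId instruction, which yields the executing replica's id as
// a U32 scalar.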
XlaOp XlaBuilder::ReplicaId() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
  });
}

XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                   absl::Span<const int64> window_dimensions,
                                   absl::Span<const int64> window_strides,
                                   Padding padding, XlaOp source,
                                   XlaOp init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);

    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions,
        // so we create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          SelectAndScatterInternal(operand, select, window_dimensions,
                                   window_strides, padding_values, source,
                                   init_value, scatter));
      instr.set_custom_call_target("DynamicSelectAndScatterSamePadding");
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                            {operand, source, init_value});
    }
    return SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides, padding_values,
        source, init_value, scatter);
  });
}

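// Builds the shared HloInstructionProto for SelectAndScatter: infers the
// window and result shape and attaches the select and scatter computations.
// Callers decide whether it becomes a kSelectAndScatter or a CustomCall.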
StatusOr<HloInstructionProto> XlaBuilder::SelectAndScatterInternal(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  HloInstructionProto instr;

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
  TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
  TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
  TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
                      select.GetProgramShape());
  TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
                      scatter.GetProgramShape());
  TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferSelectAndScatterShape(
                          *operand_shape, select_shape, instr.window(),
                          *source_shape, *init_shape, scatter_shape));
  *instr.mutable_shape() = shape.ToProto();

  AddCalledComputation(select, &instr);
  AddCalledComputation(scatter, &instr);
  return instr;
}

XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(HloInstructionProto instr,
                        SelectAndScatterInternal(
                            operand, select, window_dimensions, window_strides,
                            padding, source, init_value, scatter));

    return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
                          {operand, source, init_value});
  });
}

XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferReducePrecisionShape(
                            *operand_shape, exponent_bits, mantissa_bits));
    return ReducePrecisionInternal(shape, operand, exponent_bits,
                                   mantissa_bits);
  });
}

StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
                                                    XlaOp operand,
                                                    const int exponent_bits,
                                                    const int mantissa_bits) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_exponent_bits(exponent_bits);
  instr.set_mantissa_bits(mantissa_bits);
  return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
                        {operand});
}

void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Send HLO takes two operands: a data operand and a token. Generate the
    // token to pass into the send.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    return SendWithToken(operand, token, handle);
  });
}

XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Send must use a device-to-device channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Recv HLO takes a single token operand. Generate the token to pass into
    // the Recv and RecvDone instructions.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    XlaOp recv = RecvWithToken(token, shape, handle);

    // The RecvDone instruction produces a tuple of the data and a token.
    // Return an XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto recv_data;
    *recv_data.mutable_shape() = shape.ToProto();
    recv_data.set_tuple_index(0);
    return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
                          {recv});
  });
}

XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Recv must use a device-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}

XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
                             const Shape& shape_with_layout,
                             const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Shape passed to SendToHost must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "SendToHost shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    // TODO(b/111544877): Support tuple shapes.
    if (!operand_shape->IsArray()) {
      return InvalidArgument("SendToHost only supports array shapes, shape: %s",
                             ShapeUtil::HumanString(*operand_shape));
    }

    if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
      return InvalidArgument("SendToHost must use a device-to-host channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape_with_layout,
                                   ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    send_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    send_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
                               const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Shape passed to RecvFromHost must have a layout");
    }

    // TODO(b/111544877): Support tuple shapes.
    if (!shape.IsArray()) {
      return InvalidArgument(
          "RecvFromHost only supports array shapes, shape: %s",
          ShapeUtil::HumanString(shape));
    }

    if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
      return InvalidArgument("RecvFromHost must use a host-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    recv_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    recv_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}

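// Returns the runtime size of the given dimension as a 32-bit integer scalar.
// For a static dimension this folds to a constant; only dynamic dimensions
// produce a kGetDimensionSize instruction.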
XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
                                         *operand_shape, dimension));
    // Calling GetDimensionSize on a static dimension returns a constant
    // instruction.
    if (!operand_shape->is_dynamic_dimension(dimension)) {
      return ConstantR0<int32>(this, operand_shape->dimensions(dimension));
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
                          {operand});
  });
}

XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    Shape shape = *operand_shape;
    shape.set_dynamic_dimension(dimension, false);
    // Setting an op's dynamic dimension to its static size removes the dynamic
    // dimension.
    XlaOp static_size =
        ConstantR0<int32>(this, operand_shape->dimensions(dimension));

    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
                          {operand, static_size});
  });
}

XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));

    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferSetDimensionSizeShape(
                            *operand_shape, *val_shape, dimension));
    return SetDimensionSizeInternal(shape, operand, val, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::SetDimensionSizeInternal(const Shape& shape,
                                                     XlaOp operand, XlaOp val,
                                                     int64 dimension) {
  // Setting an op's dynamic dimension to the static size is a noop.
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
                      LookUpInstruction(val));
  if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
      HloOpcode::kConstant) {
    TF_ASSIGN_OR_RETURN(auto literal,
                        Literal::CreateFromProto(val_proto->literal(), true));
    if (literal.Get<int32>({}) == shape.dimensions(dimension)) {
      return operand;
    }
  }

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.add_dimensions(dimension);
  return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
                        {operand, val});
}

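// An op is considered constant if it does not transitively depend on any
// parameter of the computation; the check walks the operand graph with
// IsConstantVisitor.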
StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
  TF_RETURN_IF_ERROR(first_error_);

  // Verify that the handle is valid.
  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());

  bool is_constant = true;
  absl::flat_hash_set<int64> visited;
  IsConstantVisitor(operand.handle(), &visited, &is_constant);
  return is_constant;
}

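// Builds a companion computation whose result has the same structure as
// root_op's shape but with PRED element type: each output value indicates
// whether the corresponding value in the original graph is dynamic. The
// per-opcode rewrite rules are documented inside process_instruction below.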
StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
                    GetNextId());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() =
      ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();

  std::vector<HloComputationProto> called_computations;
  auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
                                 int64 operand_index) -> StatusOr<bool> {
    int64 operand_id = instr_proto->operand_ids(operand_index);
    bool is_constant = true;
    absl::flat_hash_set<int64> visited;
    IsConstantVisitor(operand_id, &visited, &is_constant);
    return is_constant;
  };
  // Processes an instruction and copies it into the new graph. The new node
  // in the new graph will have its id set to `id`.
  auto process_instruction = [&](const HloInstructionProto* instr_proto,
                                 bool need_rewrite, int64 id,
                                 absl::Span<int64 const> operand_ids,
                                 int64* global_id) {
    // Rewrite the instruction with the following rules:
    // - Unary ops: Convert into bitcast (identity) with type Pred.
    // - Binary ops: Convert into binary or.
    // - Select: Convert into binary or with its two data operands.
    // - Concat / Tuple / GTE / Bitcast: Copy.
    // - Param: Convert to constant True.
    // - GetDimensionSize: Convert to constant True if the dimension is
    //   dynamic, constant False if the dimension is static.
    // - Reduce: Convert to reduce or.
    // - Constant: Convert to constant False.
    // - Reshape, slice, transpose, pad:
    //   Convert into predicate type with same opcode.
    // - Other ops: Not supported.
    // Create the instruction for the new handle.
    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                        StringToHloOpcode(instr_proto->opcode()));
    auto* new_instr = entry.add_instructions();
    *new_instr = *instr_proto;
    new_instr->set_id(id);
    new_instr->mutable_operand_ids()->Clear();
    for (auto operand_id : operand_ids) {
      new_instr->mutable_operand_ids()->Add(operand_id);
    }

    if (!need_rewrite) {
      *new_instr->mutable_name() =
          GetFullName(instr_proto->opcode(), kNameSeparator, id);
      if (opcode == HloOpcode::kReduce) {
        // Copy the reducer to the new module, with a new id that is the same
        // as the reduce op's.
        HloComputationProto* reducer =
            &embedded_[new_instr->called_computation_ids(0)];
        int64 reducer_id = (*global_id)++;
        new_instr->clear_called_computation_ids();
        new_instr->add_called_computation_ids(reducer_id);
        called_computations.push_back(CopyReducer(
            reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
      }
      return Status::OK();
    }
    *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
    Shape new_shape(new_instr->shape());
    switch (opcode) {
      case HloOpcode::kAbs:
      case HloOpcode::kRoundNearestAfz:
      case HloOpcode::kBitcast:
      case HloOpcode::kCeil:
      case HloOpcode::kCollectivePermuteDone:
      case HloOpcode::kCos:
      case HloOpcode::kClz:
      case HloOpcode::kExp:
      case HloOpcode::kExpm1:
      case HloOpcode::kFloor:
      case HloOpcode::kImag:
      case HloOpcode::kIsFinite:
      case HloOpcode::kLog:
      case HloOpcode::kLog1p:
      case HloOpcode::kNot:
      case HloOpcode::kNegate:
      case HloOpcode::kPopulationCount:
      case HloOpcode::kReal:
      case HloOpcode::kRsqrt:
      case HloOpcode::kLogistic:
      case HloOpcode::kSign:
      case HloOpcode::kSin:
      case HloOpcode::kConvert:
      case HloOpcode::kSqrt:
      case HloOpcode::kCbrt:
      case HloOpcode::kTanh:
        CHECK_EQ(instr_proto->operand_ids_size(), 1);
        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
        break;
      case HloOpcode::kAdd:
      case HloOpcode::kAtan2:
      case HloOpcode::kDivide:
      case HloOpcode::kComplex:
      case HloOpcode::kMaximum:
      case HloOpcode::kMinimum:
      case HloOpcode::kMultiply:
      case HloOpcode::kPower:
      case HloOpcode::kRemainder:
      case HloOpcode::kSubtract:
      case HloOpcode::kCompare:
      case HloOpcode::kAnd:
      case HloOpcode::kOr:
      case HloOpcode::kXor:
      case HloOpcode::kShiftLeft:
      case HloOpcode::kShiftRightArithmetic:
      case HloOpcode::kShiftRightLogical:
        CHECK_EQ(instr_proto->operand_ids_size(), 2);
        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
        break;
      case HloOpcode::kSelect: {
        TF_ASSIGN_OR_RETURN(bool constant_predicate,
                            operand_is_constant(instr_proto, 0));
        if (!constant_predicate) {
          // If the predicate operand is not constant, conservatively assume
          // that all result values are dynamic.
          SetInstructionAsConstant(new_instr, id, new_shape, true);
        }
        break;
      }
      case HloOpcode::kGather: {
        TF_ASSIGN_OR_RETURN(bool constant_indices,
                            operand_is_constant(instr_proto, 1));
        if (!constant_indices) {
          // If the indices operand is not constant, conservatively assume
          // that all result values are dynamic.
          SetInstructionAsConstant(new_instr, id, new_shape, true);
        }
        break;
      }
      case HloOpcode::kReduce: {
        auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
        int64 reducer_id = (*global_id)++;
        new_instr->clear_called_computation_ids();
        new_instr->add_called_computation_ids(reducer_id);
        called_computations.push_back(CopyReducer(
            reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
        break;
      }
      case HloOpcode::kTuple:
      case HloOpcode::kTranspose:
      case HloOpcode::kSlice:
      case HloOpcode::kBroadcast:
      case HloOpcode::kConcatenate:
      case HloOpcode::kReshape:
      case HloOpcode::kPad:
        break;
      case HloOpcode::kGetTupleElement: {
        // Rewrite parameter followed by gte into constants to avoid
        // rematerializing the tuple parameter (could be very large).
        int64 operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));
        TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                            StringToHloOpcode(operand_proto->opcode()));
        if (operand_opcode == HloOpcode::kParameter) {
          SetInstructionAsConstant(new_instr, id, new_shape, true);
        }
        break;
      }
      case HloOpcode::kGetDimensionSize: {
        int64 dimension = instr_proto->dimensions(0);
        int64 operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));

        SetInstructionAsConstant(
            new_instr, id, new_shape,
            operand_proto->shape().is_dynamic_dimension(dimension));
        break;
      }
      case HloOpcode::kConstant:
      case HloOpcode::kIota:
        SetInstructionAsConstant(new_instr, id, new_shape, false);
        break;
      case HloOpcode::kCustomCall:
        if (instr_proto->custom_call_target() == "SetBound") {
          SetInstructionAsConstant(new_instr, id, new_shape, true);
          break;
        } else {
          return InvalidArgument(
              "Dynamic inferencing on custom call %s is not supported",
              instr_proto->DebugString());
        }
      case HloOpcode::kParameter:
        SetInstructionAsConstant(new_instr, id, new_shape, true);
        break;
      default:
        return InvalidArgument("Dynamic inferencing %s is not supported",
                               instr_proto->DebugString());
    }
    *new_instr->mutable_name() =
        GetFullName(instr_proto->opcode(), kNameSeparator, id);
    return Status::OK();
  };

  struct WorkItem {
    explicit WorkItem(int64 handle, bool need_rewrite)
        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
    int64 handle;
    // If need_rewrite is true, the instruction will be copied and rewritten
    // into a pred instruction indicating whether each value is dynamic. If
    // need_rewrite is false, simply copy the instruction to the output graph.
    // E.g.,
    // for select(P, A, B), we need to rewrite A and B into predicates, but
    // don't need to rewrite P.
    bool need_rewrite;
    // Used in the DFS to remember the ids of processed operands of this item.
    std::vector<int64> processed_operands;
    // Whether this node has been visited before.
    bool visited;
  };
  // Only copy each pair of {handle, need_rewrite} once. Value is the id in the
  // new graph.
  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
  // Monotonically increasing id to assign to new instructions.
  int64 global_id = 0;
  // The result id of the last rewritten item -- return value of last stack
  // item.
  int64 stacktop_id = -1;
  std::vector<WorkItem> worklist;
  worklist.push_back(WorkItem(root->id(), true));
  while (!worklist.empty()) {
    WorkItem& item = worklist.back();
    auto item_key = std::make_pair(item.handle, item.need_rewrite);
    auto iter = seen.find(item_key);
    // Already processed this item. Return previous results.
    if (iter != seen.end()) {
      stacktop_id = iter->second;
      worklist.pop_back();
      continue;
    }

    int64 next_operand = item.processed_operands.size();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(item.handle));
    VLOG(3) << "Processing " << instr_proto->name();
    if (!item.visited) {
      item.visited = true;
    } else {
      // Record previous processed operand.
      item.processed_operands.push_back(stacktop_id);
      next_operand++;
    }
    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                        StringToHloOpcode(instr_proto->opcode()));
    bool should_visit_operand = true;
    if (opcode == HloOpcode::kGetDimensionSize) {
      // We rewrite GetDimensionSize instructions into constants based on the
      // operand shape, so there is no need to visit their operands.
      should_visit_operand = false;
    }

    if (opcode == HloOpcode::kGetTupleElement) {
      int64 operand_handle = instr_proto->operand_ids(0);
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          LookUpInstructionByHandle(operand_handle));
      TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                          StringToHloOpcode(operand_proto->opcode()));
      if (operand_opcode == HloOpcode::kParameter) {
        // Don't rematerialize the whole parameter if it's followed by a GTE.
        should_visit_operand = false;
      }
    }

    if (opcode == HloOpcode::kSelect) {
      TF_ASSIGN_OR_RETURN(bool constant_predicate,
                          operand_is_constant(instr_proto, 0));
      // If the predicate operand (first operand) is non-constant, we don't
      // visit the operands and we say that all result values are dynamic.
      should_visit_operand = constant_predicate;
    }
    if (opcode == HloOpcode::kGather) {
      TF_ASSIGN_OR_RETURN(bool constant_indices,
                          operand_is_constant(instr_proto, 1));
      // If the indices operand (second operand) is non-constant, we don't
      // visit the operands and we say that all result values are dynamic.
      should_visit_operand = constant_indices;
    }
    if (opcode == HloOpcode::kGetDimensionSize && next_operand == 0) {
      // Always rewrite GetDimensionSize into a constant.
      item.need_rewrite = true;
    }
    if (next_operand >= instr_proto->operand_ids_size() ||
        !should_visit_operand || InstrIsSetBound(instr_proto)) {
      // No more operands to process, process self.
      int64 new_id = global_id++;
      VLOG(3) << "new_id: " << new_id << " instr: " << instr_proto->name();
      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
                                             new_id, item.processed_operands,
                                             &global_id));
      stacktop_id = new_id;
      seen[item_key] = stacktop_id;
      worklist.pop_back();
    } else {
      // Visit and process operand. If an operand doesn't need rewrite
      // (predicate of kSelect, or indices of kGather), we also don't rewrite
      // its ancestors.
      WorkItem next_item(instr_proto->operand_ids(next_operand),
                         item.need_rewrite);
      if (opcode == HloOpcode::kSelect && next_operand == 0) {
        next_item.need_rewrite = false;
      }
      if (opcode == HloOpcode::kGather && next_operand == 1) {
        next_item.need_rewrite = false;
      }
      // Push next operand into worklist.
      worklist.push_back(next_item);
    }
  }
  TF_RET_CHECK(stacktop_id != -1);
  entry.set_root_id(stacktop_id);
  absl::c_sort(*entry.mutable_instructions(),
               [](const HloInstructionProto& p1,
                  const HloInstructionProto& p2) { return p1.id() < p2.id(); });
  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& called_comp : called_computations) {
    *module->add_computations() = called_comp;
  }
  *module->add_computations() = std::move(entry);
  // Make sure all ids appear in the computation in ascending order.
  absl::c_sort(*module->mutable_computations(),
               [](const HloComputationProto& c1,
                  const HloComputationProto& c2) { return c1.id() < c2.id(); });
  XLA_VLOG_LINES(3, module->DebugString());
  return std::move(computation);
}

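// Extracts the transitive set of instructions feeding root_op into a
// standalone computation that can be constant-folded. GetDimensionSize and
// SetBound instructions are replaced by literal constants; the request fails
// if root_op depends on a parameter.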
StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);
    string op_string =
        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
    return InvalidArgument(
        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
        " op requested for constant subgraph: %s\n\n"
        "This is an internal error that typically happens when the XLA user "
        "(e.g. TensorFlow) is attempting to determine a value that must be a "
        "compile-time constant (e.g. an array dimension) but it is not capable "
        "of being evaluated at XLA compile time.\n\n"
        "Please file a usability bug with the framework being used (e.g. "
        "TensorFlow).",
        op_string);
  }

  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                    GetNextId());
  entry.set_root_id(root->id());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() = root->shape();

  // We use std::set to keep the instruction ids in ascending order (which is
  // also a valid dependency order). The related ops will be added to the
  // subgraph in the same order.
  std::set<int64> related_ops;
  absl::flat_hash_set<int64> related_calls;  // Related computations.
  std::queue<int64> worklist;
  worklist.push(root->id());
  related_ops.insert(root->id());
  while (!worklist.empty()) {
    int64 handle = worklist.front();
    worklist.pop();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(handle));

    if (instr_proto->opcode() ==
            HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_proto)) {
      int32 constant_value = -1;
      HloInstructionProto const_instr;

      if (instr_proto->opcode() ==
          HloOpcodeString(HloOpcode::kGetDimensionSize)) {
3802 // At this point, BuildConstantSubGraph should never encounter a
3803 // GetDimensionSize with a dynamic dimension; otherwise, the IsConstant
3804 // check at the beginning of this function would have failed.
3805 //
3806 // Replace GetDimensionSize with a Constant representing the static
3807 // bound of the shape.
3808 int64 dimension = instr_proto->dimensions(0);
3809 int64 operand_handle = instr_proto->operand_ids(0);
3810 TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
3811 LookUpInstructionByHandle(operand_handle));
3812
3813 if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
3814 dynamic_dimension_is_minus_one)) {
3815 constant_value =
3816 static_cast<int32>(operand_proto->shape().dimensions(dimension));
3817 }
3818 Literal literal = LiteralUtil::CreateR0(constant_value);
3819 *const_instr.mutable_literal() = literal.ToProto();
3820 *const_instr.mutable_shape() = literal.shape().ToProto();
3821 } else {
3822 *const_instr.mutable_literal() = instr_proto->literal();
3823 *const_instr.mutable_shape() = instr_proto->shape();
3824 }
3825 *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
3826 const_instr.set_id(handle);
3827 *const_instr.mutable_name() =
3828 GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
3829 *entry.add_instructions() =
3830 const_instr; // Add to the result constant graph.
3831 } else {
3832 for (int64 id : instr_proto->operand_ids()) {
3833 if (related_ops.insert(id).second) {
3834 worklist.push(id);
3835 }
3836 }
3837 for (int64 called_id : instr_proto->called_computation_ids()) {
3838 related_calls.insert(called_id);
3839 }
3840 }
3841 }
3842
3843 // Add related ops to the computation.
3844 for (int64 id : related_ops) {
3845 TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
3846 LookUpInstructionByHandle(id));
3847
3848 if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
3849 continue;
3850 }
3851 if (InstrIsSetBound(instr_src)) {
3852 continue;
3853 }
3854 auto* instr = entry.add_instructions();
3855
3856 *instr = *instr_src;
3857 // Ensure that the instruction names are unique within the graph.
3858 const string& new_name =
3859 StrCat(instr->name(), ".", entry.id(), ".", instr->id());
3860 instr->set_name(new_name);
3861 }
3862
3863 XlaComputation computation(entry.id());
3864 HloModuleProto* module = computation.mutable_proto();
3865 module->set_name(entry.name());
3866 module->set_id(entry.id());
3867 module->set_entry_computation_name(entry.name());
3868 module->set_entry_computation_id(entry.id());
3869 *module->mutable_host_program_shape() = *program_shape;
3870 for (auto& e : embedded_) {
3871 if (related_calls.find(e.second.id()) != related_calls.end()) {
3872 *module->add_computations() = e.second;
3873 }
3874 }
3875 *module->add_computations() = std::move(entry);
3876 return std::move(computation);
3877 }
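
// Illustrative usage sketch (editorial; uses only the public builder API
// declared in xla_builder.h). A value that depends only on static shape
// information can be folded into a standalone computation:
//
//   XlaBuilder b("fold");
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "p");
//   XlaOp size = GetDimensionSize(p, 1);  // Static bound: 3.
//   StatusOr<XlaComputation> sub =
//       b.BuildConstantSubGraph(size, /*dynamic_dimension_is_minus_one=*/false);
//
// The GetDimensionSize above is replaced by a constant 3, so the resulting
// computation can be evaluated without binding parameter `p`.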
3878
3879 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
3880 const string& computation_name) {
3881 auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
3882 sub_builder->parent_builder_ = this;
3883 sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
3884 return sub_builder;
3885 }
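
// Illustrative usage sketch (editorial): sub-builders are the standard way to
// construct computations that are then passed to ops such as Map, Reduce, or
// Call:
//
//   XlaBuilder b("main");
//   std::unique_ptr<XlaBuilder> sub = b.CreateSubBuilder("add_one");
//   XlaOp x = Parameter(sub.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
//   Add(x, ConstantR0<float>(sub.get(), 1.0f));
//   XlaComputation add_one = sub->BuildAndNoteError();
//
// (BuildAndNoteError reports failures to the parent builder.)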
3886
3887 /* static */ ConvolutionDimensionNumbers
3888 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
3889 ConvolutionDimensionNumbers dimension_numbers;
3890 dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
3891 dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
3892 dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
3893 dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
3894 dimension_numbers.set_kernel_output_feature_dimension(
3895 kConvKernelOutputDimension);
3896 dimension_numbers.set_kernel_input_feature_dimension(
3897 kConvKernelInputDimension);
3898 for (int i = 0; i < num_spatial_dims; ++i) {
3899 dimension_numbers.add_input_spatial_dimensions(i + 2);
3900 dimension_numbers.add_kernel_spatial_dimensions(i + 2);
3901 dimension_numbers.add_output_spatial_dimensions(i + 2);
3902 }
3903 return dimension_numbers;
3904 }
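
// Illustrative sketch (editorial): for num_spatial_dims == 2, and assuming
// the kConv* constants declared in xla_builder.h place batch at 0, feature at
// 1, and kernel output/input feature at 0/1, the defaults correspond to NCHW
// activations and OIHW kernels:
//
//   input/output: {batch = 0, feature = 1, spatial = {2, 3}}
//   kernel:       {output_feature = 0, input_feature = 1, spatial = {2, 3}}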
3905
3906 /* static */ Status XlaBuilder::Validate(
3907 const ConvolutionDimensionNumbers& dnum) {
3908 if (dnum.input_spatial_dimensions_size() < 2) {
3909 return FailedPrecondition("input spatial dimensions < 2: %d",
3910 dnum.input_spatial_dimensions_size());
3911 }
3912 if (dnum.kernel_spatial_dimensions_size() < 2) {
3913 return FailedPrecondition("kernel spatial dimensions < 2: %d",
3914 dnum.kernel_spatial_dimensions_size());
3915 }
3916 if (dnum.output_spatial_dimensions_size() < 2) {
3917 return FailedPrecondition("output spatial dimensions < 2: %d",
3918 dnum.output_spatial_dimensions_size());
3919 }
3920
3921 if (std::set<int64>(
3922 {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
3923 dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
3924 .size() != 4) {
3925 return FailedPrecondition(
3926 "dimension numbers for the input are not unique: (%d, %d, %d, "
3927 "%d)",
3928 dnum.input_batch_dimension(), dnum.input_feature_dimension(),
3929 dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
3930 }
3931 if (std::set<int64>({dnum.kernel_output_feature_dimension(),
3932 dnum.kernel_input_feature_dimension(),
3933 dnum.kernel_spatial_dimensions(0),
3934 dnum.kernel_spatial_dimensions(1)})
3935 .size() != 4) {
3936 return FailedPrecondition(
3937 "dimension numbers for the weight are not unique: (%d, %d, %d, "
3938 "%d)",
3939 dnum.kernel_output_feature_dimension(),
3940 dnum.kernel_input_feature_dimension(),
3941 dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
3942 }
3943 if (std::set<int64>({dnum.output_batch_dimension(),
3944 dnum.output_feature_dimension(),
3945 dnum.output_spatial_dimensions(0),
3946 dnum.output_spatial_dimensions(1)})
3947 .size() != 4) {
3948 return FailedPrecondition(
3949 "dimension numbers for the output are not unique: (%d, %d, %d, "
3950 "%d)",
3951 dnum.output_batch_dimension(), dnum.output_feature_dimension(),
3952 dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
3953 }
3954 return Status::OK();
3955 }
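
// Illustrative sketch (editorial): the defaults produced above always pass
// this check, e.g.
//
//   auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2);
//   TF_CHECK_OK(XlaBuilder::Validate(dnums));
//
// whereas reusing a dimension (say, input batch == input feature) trips one
// of the uniqueness checks and returns FailedPrecondition.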
3956
3957 StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
3958 HloOpcode opcode,
3959 absl::Span<const XlaOp> operands) {
3960 TF_RETURN_IF_ERROR(first_error_);
3961
3962 const int64 handle = GetNextId();
3963 instr.set_id(handle);
3964 instr.set_opcode(HloOpcodeString(opcode));
3965 if (instr.name().empty()) {
3966 instr.set_name(instr.opcode());
3967 }
3968 for (const auto& operand : operands) {
3969 if (operand.builder_ == nullptr) {
3970 return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
3971 }
3972 if (operand.builder_ != this) {
3973 return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
3974 operand.builder_->name(), this->name());
3975 }
3976 instr.add_operand_ids(operand.handle());
3977 }
3978
3979 if (one_shot_metadata_.has_value()) {
3980 *instr.mutable_metadata() = one_shot_metadata_.value();
3981 one_shot_metadata_.reset();
3982 } else {
3983 *instr.mutable_metadata() = metadata_;
3984 }
3985 if (sharding_) {
3986 *instr.mutable_sharding() = *sharding_;
3987 }
3988 *instr.mutable_frontend_attributes() = frontend_attributes_;
3989
3990 handle_to_index_[handle] = instructions_.size();
3991 instructions_.push_back(std::move(instr));
3992 instruction_shapes_.push_back(
3993 absl::make_unique<Shape>(instructions_.back().shape()));
3994
3995 XlaOp op(handle, this);
3996 return op;
3997 }
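
// Illustrative sketch (editorial): the operand checks above are what reject
// mixing ops across builders:
//
//   XlaBuilder b1("a"), b2("b");
//   XlaOp x = ConstantR0<float>(&b1, 1.0f);
//   XlaOp y = ConstantR0<float>(&b2, 2.0f);
//   Add(x, y);  // Recorded as InvalidArgument on b1; b1.Build() then fails.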
3998
3999 StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
4000 absl::Span<const XlaOp> operands) {
4001 HloInstructionProto instr;
4002 *instr.mutable_shape() = shape.ToProto();
4003 return AddInstruction(std::move(instr), opcode, operands);
4004 }
4005
4006 void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
4007 HloInstructionProto* instr) {
4008 absl::flat_hash_map<int64, int64> remapped_ids;
4009 std::vector<HloComputationProto> imported_computations;
4010 imported_computations.reserve(computation.proto().computations_size());
4011 // First, import the computations, generating new IDs for them and
4012 // capturing the old->new mappings in remapped_ids.
4013 for (const HloComputationProto& e : computation.proto().computations()) {
4014 HloComputationProto new_computation(e);
4015 int64 computation_id = GetNextId();
4016 remapped_ids[new_computation.id()] = computation_id;
4017 SetProtoIdAndName(&new_computation,
4018 GetBaseName(new_computation.name(), kNameSeparator),
4019 kNameSeparator, computation_id);
4020 for (auto& instruction : *new_computation.mutable_instructions()) {
4021 int64 instruction_id = GetNextId();
4022 remapped_ids[instruction.id()] = instruction_id;
4023 SetProtoIdAndName(&instruction,
4024 GetBaseName(instruction.name(), kNameSeparator),
4025 kNameSeparator, instruction_id);
4026 }
4027 new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));
4028
4029 imported_computations.push_back(std::move(new_computation));
4030 }
4031 // Once we have imported all the computations and captured all the ID
4032 // mappings, we go back and fix up the IDs in the imported computations.
4033 instr->add_called_computation_ids(
4034 remapped_ids.at(computation.proto().entry_computation_id()));
4035 for (auto& imported_computation : imported_computations) {
4036 for (auto& instruction : *imported_computation.mutable_instructions()) {
4037 for (auto& operand_id : *instruction.mutable_operand_ids()) {
4038 operand_id = remapped_ids.at(operand_id);
4039 }
4040 for (auto& control_predecessor_id :
4041 *instruction.mutable_control_predecessor_ids()) {
4042 control_predecessor_id = remapped_ids.at(control_predecessor_id);
4043 }
4044 for (auto& called_computation_id :
4045 *instruction.mutable_called_computation_ids()) {
4046 called_computation_id = remapped_ids.at(called_computation_id);
4047 }
4048 }
4049
4050 int64 computation_id = imported_computation.id();
4051 embedded_.insert({computation_id, std::move(imported_computation)});
4052 }
4053 }
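
// Illustrative sketch (editorial): every op that takes an XlaComputation
// (Call, Map, Reduce, While, ...) funnels through AddCalledComputation, so
// calling the same computation twice imports two independently renumbered
// copies:
//
//   Call(&b, add_one, {ConstantR0<float>(&b, 1.0f)});
//   Call(&b, add_one, {ConstantR0<float>(&b, 2.0f)});  // Remapped again.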
4054
4055 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
4056 const XlaOp op) const {
4057 TF_RETURN_IF_ERROR(first_error_);
4058 return LookUpInstructionInternal<const HloInstructionProto*>(op);
4059 }
4060
4061 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
4062 int64 handle) const {
4063 return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
4064 }
4065
4066 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
4067 const XlaOp op) {
4068 TF_RETURN_IF_ERROR(first_error_);
4069 return LookUpInstructionInternal<HloInstructionProto*>(op);
4070 }
4071
4072 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
4073 int64 handle) {
4074 return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
4075 }
4076
4077 // Enqueues a "retrieve parameter value" instruction for a parameter that was
4078 // passed to the computation.
4079 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
4080 const string& name) {
4081 std::vector<bool> empty_bools;
4082 return Parameter(builder, parameter_number, shape, name, empty_bools);
4083 }
4084
4085 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
4086 const string& name,
4087 const std::vector<bool>& replicated_at_leaf_buffers) {
4088 return builder->Parameter(parameter_number, shape, name,
4089 replicated_at_leaf_buffers);
4090 }
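
// Illustrative sketch (editorial): parameters are the runtime inputs of the
// built computation and are numbered densely starting at 0:
//
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");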
4091
4092 // Enqueues a constant with the value of the given literal onto the
4093 // computation.
4094 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
4095 return builder->ConstantLiteral(literal);
4096 }
4097
4098 XlaOp Broadcast(const XlaOp operand, absl::Span<const int64> broadcast_sizes) {
4099 return operand.builder()->Broadcast(operand, broadcast_sizes);
4100 }
4101
4102 XlaOp BroadcastInDim(const XlaOp operand,
4103 const absl::Span<const int64> out_dim_size,
4104 const absl::Span<const int64> broadcast_dimensions) {
4105 return operand.builder()->BroadcastInDim(operand, out_dim_size,
4106 broadcast_dimensions);
4107 }
4108
4109 XlaOp Copy(const XlaOp operand) {
4110 return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
4111 }
4112
4113 XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
4114 const PaddingConfig& padding_config) {
4115 return operand.builder()->Pad(operand, padding_value, padding_config);
4116 }
4117
4118 XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno, int64 pad_lo,
4119 int64 pad_hi) {
4120 return operand.builder()->PadInDim(operand, padding_value, dimno, pad_lo,
4121 pad_hi);
4122 }
4123
4124 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> dimensions,
4125 absl::Span<const int64> new_sizes) {
4126 return operand.builder()->Reshape(operand, dimensions, new_sizes);
4127 }
4128
4129 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
4130 return operand.builder()->Reshape(operand, new_sizes);
4131 }
4132
4133 XlaOp Reshape(const Shape& shape, XlaOp operand) {
4134 return operand.builder()->Reshape(shape, operand);
4135 }
4136
4137 XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
4138 absl::Span<const int64> new_size_bounds,
4139 const std::vector<bool>& dims_are_dynamic) {
4140 return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
4141 dims_are_dynamic);
4142 }
4143
4144 XlaOp ReshapeWithInferredDimension(XlaOp operand,
4145 absl::Span<const int64> new_sizes,
4146 int64 inferred_dimension) {
4147 return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
4148 }
4149
4150 XlaOp Collapse(const XlaOp operand, absl::Span<const int64> dimensions) {
4151 return operand.builder()->Collapse(operand, dimensions);
4152 }
4153
4154 XlaOp Slice(const XlaOp operand, absl::Span<const int64> start_indices,
4155 absl::Span<const int64> limit_indices,
4156 absl::Span<const int64> strides) {
4157 return operand.builder()->Slice(operand, start_indices, limit_indices,
4158 strides);
4159 }
4160
4161 XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index,
4162 int64 stride, int64 dimno) {
4163 return operand.builder()->SliceInDim(operand, start_index, limit_index,
4164 stride, dimno);
4165 }
4166
4167 XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
4168 absl::Span<const int64> slice_sizes) {
4169 return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
4170 }
4171
4172 XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
4173 absl::Span<const XlaOp> start_indices) {
4174 return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
4175 }
4176
4177 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
4178 int64 dimension) {
4179 return builder->ConcatInDim(operands, dimension);
4180 }
4181
4182 void Trace(const string& tag, const XlaOp operand) {
4183 return operand.builder()->Trace(tag, operand);
4184 }
4185
4186 XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
4187 return pred.builder()->Select(pred, on_true, on_false);
4188 }
4189
4190 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
4191 return builder->Tuple(elements);
4192 }
4193
4194 XlaOp GetTupleElement(const XlaOp tuple_data, int64 index) {
4195 return tuple_data.builder()->GetTupleElement(tuple_data, index);
4196 }
4197
4198 XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
4199 absl::Span<const int64> broadcast_dimensions) {
4200 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
4201 }
4202
4203 static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
4204 absl::Span<const int64> broadcast_dimensions,
4205 ComparisonDirection comparison_direction) {
4206 auto b = lhs.builder();
4207 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
4208 TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
4209 auto operand_element_type = operand_shape.element_type();
4210 auto compare_type =
4211 primitive_util::IsFloatingPointType(operand_element_type)
4212 ? Comparison::Type::kFloatTotalOrder
4213 : Comparison::DefaultComparisonType(operand_element_type);
4214 return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
4215 compare_type);
4216 });
4217 }
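
// Illustrative sketch (editorial): total-order comparisons make NaNs
// comparable (-NaN < -Inf < ... < -0 < +0 < ... < +Inf < +NaN), so unlike
// Eq, EqTotalOrder yields true for two identical NaNs:
//
//   XlaOp nan = ConstantR0<float>(&b, std::numeric_limits<float>::quiet_NaN());
//   Eq(nan, nan);            // false
//   EqTotalOrder(nan, nan);  // true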
4218
4219 XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
4220 absl::Span<const int64> broadcast_dimensions) {
4221 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4222 ComparisonDirection::kEq);
4223 }
4224
4225 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
4226 absl::Span<const int64> broadcast_dimensions) {
4227 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
4228 }
4229
4230 XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4231 absl::Span<const int64> broadcast_dimensions) {
4232 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4233 ComparisonDirection::kNe);
4234 }
4235
4236 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
4237 absl::Span<const int64> broadcast_dimensions) {
4238 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
4239 }
4240
4241 XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4242 absl::Span<const int64> broadcast_dimensions) {
4243 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4244 ComparisonDirection::kGe);
4245 }
4246
4247 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
4248 absl::Span<const int64> broadcast_dimensions) {
4249 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
4250 }
4251
4252 XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
4253 absl::Span<const int64> broadcast_dimensions) {
4254 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4255 ComparisonDirection::kGt);
4256 }
4257
4258 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
4259 absl::Span<const int64> broadcast_dimensions) {
4260 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
4261 }
4262
4263 XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4264 absl::Span<const int64> broadcast_dimensions) {
4265 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4266 ComparisonDirection::kLe);
4267 }
4268
4269 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
4270 absl::Span<const int64> broadcast_dimensions) {
4271 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
4272 }
4273
4274 XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
4275 absl::Span<const int64> broadcast_dimensions) {
4276 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4277 ComparisonDirection::kLt);
4278 }
4279
4280 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
4281 absl::Span<const int64> broadcast_dimensions,
4282 ComparisonDirection direction) {
4283 return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
4284 broadcast_dimensions, direction);
4285 }
4286
4287 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
4288 absl::Span<const int64> broadcast_dimensions,
4289 ComparisonDirection direction, Comparison::Type compare_type) {
4290 return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
4291 broadcast_dimensions, direction, compare_type);
4292 }
4293
4294 XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
4295 return Compare(lhs, rhs, {}, direction);
4296 }
4297
4298 XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
4299 const PrecisionConfig* precision_config,
4300 absl::optional<PrimitiveType> preferred_element_type) {
4301 return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
4302 }
4303
4304 XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
4305 const DotDimensionNumbers& dimension_numbers,
4306 const PrecisionConfig* precision_config,
4307 absl::optional<PrimitiveType> preferred_element_type) {
4308 return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
4309 precision_config, preferred_element_type);
4310 }
4311
4312 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
4313 absl::Span<const int64> window_strides, Padding padding,
4314 int64 feature_group_count, int64 batch_group_count,
4315 const PrecisionConfig* precision_config,
4316 absl::optional<PrimitiveType> preferred_element_type) {
4317 return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
4318 feature_group_count, batch_group_count,
4319 precision_config, preferred_element_type);
4320 }
4321
4322 XlaOp ConvWithGeneralPadding(
4323 const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
4324 absl::Span<const std::pair<int64, int64>> padding,
4325 int64 feature_group_count, int64 batch_group_count,
4326 const PrecisionConfig* precision_config,
4327 absl::optional<PrimitiveType> preferred_element_type) {
4328 return lhs.builder()->ConvWithGeneralPadding(
4329 lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
4330 precision_config, preferred_element_type);
4331 }
4332
4333 XlaOp ConvWithGeneralDimensions(
4334 const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
4335 Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
4336 int64 feature_group_count, int64 batch_group_count,
4337 const PrecisionConfig* precision_config,
4338 absl::optional<PrimitiveType> preferred_element_type) {
4339 return lhs.builder()->ConvWithGeneralDimensions(
4340 lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4341 batch_group_count, precision_config, preferred_element_type);
4342 }
4343
4344 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
4345 absl::Span<const int64> window_strides,
4346 absl::Span<const std::pair<int64, int64>> padding,
4347 const ConvolutionDimensionNumbers& dimension_numbers,
4348 int64 feature_group_count, int64 batch_group_count,
4349 const PrecisionConfig* precision_config,
4350 absl::optional<PrimitiveType> preferred_element_type) {
4351 return lhs.builder()->ConvGeneral(
4352 lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4353 batch_group_count, precision_config, preferred_element_type);
4354 }
4355
4356 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
4357 absl::Span<const int64> window_strides,
4358 absl::Span<const std::pair<int64, int64>> padding,
4359 absl::Span<const int64> lhs_dilation,
4360 absl::Span<const int64> rhs_dilation,
4361 const ConvolutionDimensionNumbers& dimension_numbers,
4362 int64 feature_group_count, int64 batch_group_count,
4363 const PrecisionConfig* precision_config,
4364 absl::optional<PrimitiveType> preferred_element_type) {
4365 return lhs.builder()->ConvGeneralDilated(
4366 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
4367 dimension_numbers, feature_group_count, batch_group_count,
4368 precision_config, preferred_element_type);
4369 }
4370
4371 XlaOp DynamicConvInputGrad(
4372 XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
4373 absl::Span<const int64> window_strides,
4374 absl::Span<const std::pair<int64, int64>> padding,
4375 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
4376 const ConvolutionDimensionNumbers& dimension_numbers,
4377 int64 feature_group_count, int64 batch_group_count,
4378 const PrecisionConfig* precision_config, PaddingType padding_type,
4379 absl::optional<PrimitiveType> preferred_element_type) {
4380 return lhs.builder()->DynamicConvInputGrad(
4381 input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
4382 rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
4383 precision_config, padding_type, preferred_element_type);
4384 }
4385
4386 XlaOp DynamicConvKernelGrad(
4387 XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
4388 absl::Span<const std::pair<int64, int64>> padding,
4389 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
4390 const ConvolutionDimensionNumbers& dimension_numbers,
4391 int64 feature_group_count, int64 batch_group_count,
4392 const PrecisionConfig* precision_config, PaddingType padding_type,
4393 absl::optional<PrimitiveType> preferred_element_type) {
4394 return activations.builder()->DynamicConvKernelGrad(
4395 activations, gradients, window_strides, padding, lhs_dilation,
4396 rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
4397 precision_config, padding_type, preferred_element_type);
4398 }
4399
4400 XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
4401 absl::Span<const int64> window_strides,
4402 absl::Span<const std::pair<int64, int64>> padding,
4403 absl::Span<const int64> lhs_dilation,
4404 absl::Span<const int64> rhs_dilation,
4405 const ConvolutionDimensionNumbers& dimension_numbers,
4406 int64 feature_group_count, int64 batch_group_count,
4407 const PrecisionConfig* precision_config,
4408 PaddingType padding_type,
4409 absl::optional<PrimitiveType> preferred_element_type) {
4410 return lhs.builder()->DynamicConvForward(
4411 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
4412 dimension_numbers, feature_group_count, batch_group_count,
4413 precision_config, padding_type, preferred_element_type);
4414 }
4415
4416 XlaOp Fft(const XlaOp operand, FftType fft_type,
4417 absl::Span<const int64> fft_length) {
4418 return operand.builder()->Fft(operand, fft_type, fft_length);
4419 }
4420
4421 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
4422 bool unit_diagonal,
4423 TriangularSolveOptions::Transpose transpose_a) {
4424 XlaBuilder* builder = a.builder();
4425 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
4426 TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
4427 TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
4428 xla::TriangularSolveOptions options;
4429 options.set_left_side(left_side);
4430 options.set_lower(lower);
4431 options.set_unit_diagonal(unit_diagonal);
4432 options.set_transpose_a(transpose_a);
4433 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
4434 *a_shape, *b_shape, options));
4435 return builder->TriangularSolveInternal(shape, a, b, std::move(options));
4436 });
4437 }
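
// Illustrative sketch (editorial): solving A X = B for X with a lower
// triangular A, given f32[4, 4] `a` and f32[4, 2] `b_rhs`:
//
//   TriangularSolve(a, b_rhs, /*left_side=*/true, /*lower=*/true,
//                   /*unit_diagonal=*/false,
//                   TriangularSolveOptions::NO_TRANSPOSE);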
4438
4439 XlaOp Cholesky(XlaOp a, bool lower) {
4440 XlaBuilder* builder = a.builder();
4441 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
4442 TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
4443 TF_ASSIGN_OR_RETURN(Shape shape,
4444 ShapeInference::InferCholeskyShape(*a_shape));
4445 return builder->CholeskyInternal(shape, a, lower);
4446 });
4447 }
4448
4449 XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
4450 return builder->Infeed(shape, config);
4451 }
4452
4453 void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
4454 const string& outfeed_config) {
4455 return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
4456 }
4457
4458 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
4459 absl::Span<const XlaOp> operands) {
4460 return builder->Call(computation, operands);
4461 }
4462
4463 XlaOp CustomCall(
4464 XlaBuilder* builder, const string& call_target_name,
4465 absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
4466 bool has_side_effect,
4467 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
4468 output_operand_aliasing,
4469 const Literal* literal) {
4470 return builder->CustomCall(call_target_name, operands, shape, opaque,
4471 /*operand_shapes_with_layout=*/absl::nullopt,
4472 has_side_effect, output_operand_aliasing, literal);
4473 }
4474
4475 XlaOp CustomCallWithComputation(
4476 XlaBuilder* builder, const string& call_target_name,
4477 absl::Span<const XlaOp> operands, const XlaComputation& computation,
4478 const Shape& shape, const string& opaque, bool has_side_effect,
4479 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
4480 output_operand_aliasing,
4481 const Literal* literal) {
4482 return builder->CustomCall(call_target_name, operands, computation, shape,
4483 opaque,
4484 /*operand_shapes_with_layout=*/absl::nullopt,
4485 has_side_effect, output_operand_aliasing, literal);
4486 }
4487
4488 XlaOp CustomCallWithLayout(
4489 XlaBuilder* builder, const string& call_target_name,
4490 absl::Span<const XlaOp> operands, const Shape& shape,
4491 absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
4492 bool has_side_effect,
4493 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
4494 output_operand_aliasing,
4495 const Literal* literal) {
4496 return builder->CustomCall(call_target_name, operands, shape, opaque,
4497 operand_shapes_with_layout, has_side_effect,
4498 output_operand_aliasing, literal);
4499 }
4500
4501 XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
4502 absl::Span<const int64> broadcast_dimensions) {
4503 return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
4504 broadcast_dimensions);
4505 }
4506
4507 XlaOp Conj(const XlaOp operand) {
4508 return Complex(Real(operand), Neg(Imag(operand)));
4509 }
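
// Illustrative sketch (editorial): Conj is composed from existing ops rather
// than emitting a dedicated HLO; for z = 3 + 4i it produces 3 - 4i:
//
//   XlaOp z = ConstantR0<complex64>(&b, complex64(3.0f, 4.0f));
//   Conj(z);  // Complex(Real(z), Neg(Imag(z)))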
4510
4511 XlaOp Add(const XlaOp lhs, const XlaOp rhs,
4512 absl::Span<const int64> broadcast_dimensions) {
4513 return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
4514 broadcast_dimensions);
4515 }
4516
4517 XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
4518 absl::Span<const int64> broadcast_dimensions) {
4519 return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
4520 broadcast_dimensions);
4521 }
4522
4523 XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
4524 absl::Span<const int64> broadcast_dimensions) {
4525 return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
4526 broadcast_dimensions);
4527 }
4528
4529 XlaOp Div(const XlaOp lhs, const XlaOp rhs,
4530 absl::Span<const int64> broadcast_dimensions) {
4531 return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
4532 broadcast_dimensions);
4533 }
4534
4535 XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
4536 absl::Span<const int64> broadcast_dimensions) {
4537 return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
4538 broadcast_dimensions);
4539 }
4540
4541 XlaOp Max(const XlaOp lhs, const XlaOp rhs,
4542 absl::Span<const int64> broadcast_dimensions) {
4543 return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
4544 broadcast_dimensions);
4545 }
4546
4547 XlaOp Min(const XlaOp lhs, const XlaOp rhs,
4548 absl::Span<const int64> broadcast_dimensions) {
4549 return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
4550 broadcast_dimensions);
4551 }
4552
4553 XlaOp And(const XlaOp lhs, const XlaOp rhs,
4554 absl::Span<const int64> broadcast_dimensions) {
4555 return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
4556 broadcast_dimensions);
4557 }
4558
4559 XlaOp Or(const XlaOp lhs, const XlaOp rhs,
4560 absl::Span<const int64> broadcast_dimensions) {
4561 return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
4562 broadcast_dimensions);
4563 }
4564
4565 XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
4566 absl::Span<const int64> broadcast_dimensions) {
4567 return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
4568 broadcast_dimensions);
4569 }
4570
4571 XlaOp Not(const XlaOp operand) {
4572 return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
4573 }
4574
4575 XlaOp PopulationCount(const XlaOp operand) {
4576 return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
4577 }
4578
4579 XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
4580 absl::Span<const int64> broadcast_dimensions) {
4581 return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
4582 broadcast_dimensions);
4583 }
4584
4585 XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
4586 absl::Span<const int64> broadcast_dimensions) {
4587 return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
4588 broadcast_dimensions);
4589 }
4590
4591 XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
4592 absl::Span<const int64> broadcast_dimensions) {
4593 return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
4594 broadcast_dimensions);
4595 }
4596
4597 XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
4598 const XlaComputation& computation,
4599 absl::Span<const int64> dimensions_to_reduce) {
4600 return operand.builder()->Reduce(operand, init_value, computation,
4601 dimensions_to_reduce);
4602 }
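
// Illustrative sketch (editorial; assumes `add` is a scalar f32 addition
// computation built on a sub-builder): reducing dimension 1 of an f32[2, 3]
// array produces the f32[2] row sums:
//
//   XlaOp m = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "m");
//   Reduce(m, ConstantR0<float>(&b, 0.0f), add, /*dimensions_to_reduce=*/{1});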
4603
4604 // Reduces several arrays simultaneously along the provided dimensions, given
4605 // "computation" as a reduction operator.
4606 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
4607 absl::Span<const XlaOp> init_values,
4608 const XlaComputation& computation,
4609 absl::Span<const int64> dimensions_to_reduce) {
4610 return builder->Reduce(operands, init_values, computation,
4611 dimensions_to_reduce);
4612 }
4613
4614 XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
4615 const XlaComputation& computation) {
4616 return operand.builder()->ReduceAll(operand, init_value, computation);
4617 }
4618
4619 XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
4620 const XlaComputation& computation,
4621 absl::Span<const int64> window_dimensions,
4622 absl::Span<const int64> window_strides, Padding padding) {
4623 return operand.builder()->ReduceWindow(operand, init_value, computation,
4624 window_dimensions, window_strides,
4625 padding);
4626 }
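
// Illustrative sketch (editorial; assumes `max` is a scalar f32 maximum
// computation): 2x2 max pooling with stride 2 over an NHWC image:
//
//   XlaOp img = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 4, 4, 1}), "x");
//   ReduceWindow(img,
//                ConstantR0<float>(&b, -std::numeric_limits<float>::infinity()),
//                max, /*window_dimensions=*/{1, 2, 2, 1},
//                /*window_strides=*/{1, 2, 2, 1}, Padding::kValid);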
4627
4628 XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
4629 absl::Span<const XlaOp> init_values,
4630 const XlaComputation& computation,
4631 absl::Span<const int64> window_dimensions,
4632 absl::Span<const int64> window_strides, Padding padding) {
4633 CHECK(!operands.empty());
4634 return operands[0].builder()->ReduceWindow(operands, init_values, computation,
4635 window_dimensions, window_strides,
4636 padding);
4637 }
4638
4639 XlaOp ReduceWindowWithGeneralPadding(
4640 const XlaOp operand, const XlaOp init_value,
4641 const XlaComputation& computation,
4642 absl::Span<const int64> window_dimensions,
4643 absl::Span<const int64> window_strides,
4644 absl::Span<const int64> base_dilations,
4645 absl::Span<const int64> window_dilations,
4646 absl::Span<const std::pair<int64, int64>> padding) {
4647 return operand.builder()->ReduceWindowWithGeneralPadding(
4648 absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
4649 window_dimensions, window_strides, base_dilations, window_dilations,
4650 padding);
4651 }
4652
4653 XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
4654 int64 shard_count,
4655 absl::Span<const ReplicaGroup> replica_groups,
4656 const absl::optional<ChannelHandle>& channel_id,
4657 const absl::optional<Layout>& layout,
4658 const absl::optional<bool> use_global_device_ids) {
4659 return operand.builder()->AllGather(operand, all_gather_dimension,
4660 shard_count, replica_groups, channel_id,
4661 layout, use_global_device_ids);
4662 }
4663
4664 XlaOp CrossReplicaSum(const XlaOp operand,
4665 absl::Span<const ReplicaGroup> replica_groups) {
4666 return operand.builder()->CrossReplicaSum(operand, replica_groups);
4667 }
4668
4669 XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
4670 absl::Span<const ReplicaGroup> replica_groups,
4671 const absl::optional<ChannelHandle>& channel_id,
4672 const absl::optional<Shape>& shape_with_layout) {
4673 return operand.builder()->AllReduce(operand, computation, replica_groups,
4674 channel_id, shape_with_layout);
4675 }
4676
4677 XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
4678 int64 concat_dimension, int64 split_count,
4679 const std::vector<ReplicaGroup>& replica_groups,
4680 const absl::optional<Layout>& layout) {
4681 return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
4682 split_count, replica_groups, layout);
4683 }
4684
4685 XlaOp AllToAllTuple(const XlaOp operand, int64 split_dimension,
4686 int64 concat_dimension, int64 split_count,
4687 const std::vector<ReplicaGroup>& replica_groups,
4688 const absl::optional<Layout>& layout) {
4689 return operand.builder()->AllToAllTuple(operand, split_dimension,
4690 concat_dimension, split_count,
4691 replica_groups, layout);
4692 }
4693
4694 XlaOp CollectivePermute(
4695 const XlaOp operand,
4696 const std::vector<std::pair<int64, int64>>& source_target_pairs) {
4697 return operand.builder()->CollectivePermute(operand, source_target_pairs);
4698 }
4699
4700 XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }
4701
4702 XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
4703 absl::Span<const int64> window_dimensions,
4704 absl::Span<const int64> window_strides, Padding padding,
4705 const XlaOp source, const XlaOp init_value,
4706 const XlaComputation& scatter) {
4707 return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
4708 window_strides, padding, source,
4709 init_value, scatter);
4710 }
4711
4712 XlaOp SelectAndScatterWithGeneralPadding(
4713 const XlaOp operand, const XlaComputation& select,
4714 absl::Span<const int64> window_dimensions,
4715 absl::Span<const int64> window_strides,
4716 absl::Span<const std::pair<int64, int64>> padding, const XlaOp source,
4717 const XlaOp init_value, const XlaComputation& scatter) {
4718 return operand.builder()->SelectAndScatterWithGeneralPadding(
4719 operand, select, window_dimensions, window_strides, padding, source,
4720 init_value, scatter);
4721 }
4722
4723 XlaOp Abs(const XlaOp operand) {
4724 return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
4725 }
4726
4727 XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
4728 absl::Span<const int64> broadcast_dimensions) {
4729 return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
4730 broadcast_dimensions);
4731 }
4732
4733 XlaOp Exp(const XlaOp operand) {
4734 return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
4735 }
4736 XlaOp Expm1(const XlaOp operand) {
4737 return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
4738 }
4739 XlaOp Floor(const XlaOp operand) {
4740 return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
4741 }
4742 XlaOp Ceil(const XlaOp operand) {
4743 return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
4744 }
4745 XlaOp Round(const XlaOp operand) {
4746 return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
4747 }
4748 XlaOp Log(const XlaOp operand) {
4749 return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
4750 }
4751 XlaOp Log1p(const XlaOp operand) {
4752 return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
4753 }
4754 XlaOp Logistic(const XlaOp operand) {
4755 return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand);
4756 }
4757 XlaOp Sign(const XlaOp operand) {
4758 return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
4759 }
4760 XlaOp Clz(const XlaOp operand) {
4761 return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
4762 }
4763 XlaOp Cos(const XlaOp operand) {
4764 return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
4765 }
4766 XlaOp Sin(const XlaOp operand) {
4767 return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
4768 }
4769 XlaOp Tanh(const XlaOp operand) {
4770 return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
4771 }
4772 XlaOp Real(const XlaOp operand) {
4773 return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
4774 }
4775 XlaOp Imag(const XlaOp operand) {
4776 return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
4777 }
4778 XlaOp Sqrt(const XlaOp operand) {
4779 return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
4780 }
4781 XlaOp Cbrt(const XlaOp operand) {
4782 return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand);
4783 }
4784 XlaOp Rsqrt(const XlaOp operand) {
4785 return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
4786 }
4787
4788 XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
4789 absl::Span<const int64> broadcast_dimensions) {
4790 return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
4791 broadcast_dimensions);
4792 }
4793
4794 XlaOp IsFinite(const XlaOp operand) {
4795 return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
4796 }
4797
4798 XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
4799 return operand.builder()->ConvertElementType(operand, new_element_type);
4800 }
4801
4802 XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
4803 return operand.builder()->BitcastConvertType(operand, new_element_type);
4804 }
4805
4806 XlaOp Neg(const XlaOp operand) {
4807 return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
4808 }
4809
4810 XlaOp Transpose(const XlaOp operand, absl::Span<const int64> permutation) {
4811 return operand.builder()->Transpose(operand, permutation);
4812 }
4813
4814 XlaOp Rev(const XlaOp operand, absl::Span<const int64> dimensions) {
4815 return operand.builder()->Rev(operand, dimensions);
4816 }
4817
4818 XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
4819 int64 dimension, bool is_stable) {
4820 return operands[0].builder()->Sort(operands, comparator, dimension,
4821 is_stable);
4822 }
4823
4824 XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
4825 return min.builder()->Clamp(min, operand, max);
4826 }
4827
4828 XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
4829 const XlaComputation& computation, absl::Span<const int64> dimensions,
4830 absl::Span<const XlaOp> static_operands) {
4831 return builder->Map(operands, computation, dimensions, static_operands);
4832 }
4833
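// Example: drawing uniform samples in [0, 1). The distribution parameters are
// ops on the same builder; the shape fixes the element type and output dims:
//
//   XlaOp lo = ConstantR0<float>(&builder, 0.0f);
//   XlaOp hi = ConstantR0<float>(&builder, 1.0f);
//   XlaOp u = RngUniform(lo, hi, ShapeUtil::MakeShape(F32, {16}));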
XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  return initial_state.builder()->RngBitGenerator(algorithm, initial_state,
                                                  shape);
}

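// Example: a counting loop. `condition` and `body` are assumed to be built on
// sub-builders; both take the loop carry (here s32[]) as their parameter:
//
//   auto cb = builder.CreateSubBuilder("cond");
//   Lt(Parameter(cb.get(), 0, ShapeUtil::MakeShape(S32, {}), "x"),
//      ConstantR0<int32>(cb.get(), 10));
//   XlaComputation condition = cb->Build().ValueOrDie();
//   ...  // build `body` analogously, returning x + 1
//   XlaOp result = While(condition, body,
//                        /*init=*/ConstantR0<int32>(&builder, 0));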
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  return init.builder()->While(condition, body, init);
}

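// Example: a two-way branch (`neg_comp` and `id_comp` are assumed scalar
// computations; each receives its operand as a parameter):
//
//   XlaOp pred = Lt(x, ConstantR0<float>(&builder, 0.0f));
//   XlaOp out = Conditional(pred, /*true_operand=*/x, neg_comp,
//                           /*false_operand=*/x, id_comp);
//
// The N-way overload below selects among branch_computations with an s32
// branch_index instead of a pred.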
XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}

XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
                      const int mantissa_bits) {
  return operand.builder()->ReducePrecision(operand, exponent_bits,
                                            mantissa_bits);
}

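// Example: gathering whole rows of an f32[N, C] matrix with s32[K, 1] start
// indices (dimension numbers sketch the common "take rows" case; names are
// illustrative):
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 holds the row slice
//   dnums.add_collapsed_slice_dims(0);  // drop the size-1 row dimension
//   dnums.add_start_index_map(0);       // index component selects dim 0
//   dnums.set_index_vector_dim(1);
//   XlaOp rows = Gather(input, indices, dnums, /*slice_sizes=*/{1, C});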
XlaOp Gather(const XlaOp input, const XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  return input.builder()->Gather(input, start_indices, dimension_numbers,
                                 slice_sizes, indices_are_sorted);
}

XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
              const XlaOp updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return input.builder()->Scatter(input, scatter_indices, updates,
                                  update_computation, dimension_numbers,
                                  indices_are_sorted, unique_indices);
}

void Send(const XlaOp operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

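// Tokens order side-effecting ops. Example: threading a token through an
// infeed and recovering the data (shape and config are illustrative):
//
//   XlaOp token = CreateToken(&builder);
//   XlaOp infeed = InfeedWithToken(token, data_shape, /*config=*/"");
//   XlaOp data = GetTupleElement(infeed, 0);
//   XlaOp next_token = GetTupleElement(infeed, 1);
//
// AfterAll joins several tokens into one, e.g. to sequence an outfeed after
// multiple sends.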
XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}

XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64 feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64 feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64 feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
                                          grad_output, epsilon, feature_index);
}

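// Example: the shaped Iota overload fills `shape` with indices counted along
// iota_dimension (values below are illustrative):
//
//   XlaOp i = Iota(&builder, ShapeUtil::MakeShape(S32, {2, 3}),
//                  /*iota_dimension=*/1);
//   // => [[0, 1, 2],
//   //     [0, 1, 2]]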
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}

XlaOp GetDimensionSize(const XlaOp operand, int64 dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
  return operand.builder()->SetDimensionSize(operand, val, dimension);
}

XlaOp RemoveDynamicDimension(const XlaOp operand, int64 dimension) {
  return operand.builder()->RemoveDynamicDimension(operand, dimension);
}

}  // namespace xla