/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/xla_builder.h"

#include <functional>
#include <numeric>
#include <queue>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

using absl::StrCat;

namespace {

static const char kNameSeparator = '.';
// Retrieves the base name of an instruction's or computation's fully
// qualified name, using `separator` as the boundary between the base name
// and the trailing numeric identifier.
string GetBaseName(const string& name, char separator) {
  auto pos = name.rfind(separator);
  CHECK_NE(pos, string::npos) << name;
  return name.substr(0, pos);
}

// Generates a fully qualified computation/instruction name.
string GetFullName(const string& base_name, char separator, int64_t id) {
  const char separator_str[] = {separator, '\0'};
  return StrCat(base_name, separator_str, id);
}
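
// Example (illustrative only; not exercised here): the two helpers above
// round-trip a qualified name.
//   string full = GetFullName("add", kNameSeparator, 42);  // "add.42"
//   string base = GetBaseName(full, kNameSeparator);       // "add"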

// Common function to standardize setting name and IDs on computation and
// instruction proto entities.
template <typename T>
void SetProtoIdAndName(T* entry, const string& base_name, char separator,
                       int64_t id) {
  entry->set_id(id);
  entry->set_name(GetFullName(base_name, separator, id));
}

bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
  HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
  if (opcode == HloOpcode::kCustomCall &&
      instr_proto->custom_call_target() == "SetBound") {
    return true;
  }
  return false;
}

}  // namespace

namespace internal {

XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
                                    absl::Span<const XlaOp> operands,
                                    absl::string_view fusion_kind,
                                    const XlaComputation& fused_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_fusion_kind(std::string(fusion_kind));
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(auto program_shape,
                        fused_computation.GetProgramShape());
    *instr.mutable_shape() = program_shape.result().ToProto();
    builder->AddCalledComputation(fused_computation, &instr);
    return builder->AddInstruction(std::move(instr), HloOpcode::kFusion,
                                   operands);
  });
}

XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
                                     const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
                                   {operand});
  });
}

HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
  return &op.builder()
              ->instructions_[op.builder()->handle_to_index_[op.handle_]];
}

}  // namespace internal

XlaOp operator-(XlaOp x) { return Neg(x); }
XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }

XlaOp operator~(XlaOp x) { return Not(x); }
XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }

XlaOp operator>>(XlaOp x, XlaOp y) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, builder->GetShapePtr(x));
    if (!ShapeUtil::ElementIsIntegral(*shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    if (ShapeUtil::ElementIsSigned(*shape)) {
      return ShiftRightArithmetic(x, y);
    } else {
      return ShiftRightLogical(x, y);
    }
  });
}
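
// Example (illustrative): for signed operands x >> y lowers to
// ShiftRightArithmetic (sign-extending), while for unsigned operands it
// lowers to ShiftRightLogical (zero-filling):
//   XlaBuilder b("shift");
//   auto s = ConstantR0<int32_t>(&b, -8) >> ConstantR0<int32_t>(&b, 1);   // -4
//   auto u = ConstantR0<uint32_t>(&b, 8u) >> ConstantR0<uint32_t>(&b, 1); // 4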

StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  auto it = handle_to_index_.find(op.handle());
  if (it == handle_to_index_.end()) {
    return InvalidArgument("No XlaOp with handle %d", op.handle());
  }
  return instruction_shapes_.at(it->second).get();
}

StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
  TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
  return *shape;
}

StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    absl::Span<const XlaOp> operands) const {
  std::vector<Shape> operand_shapes;
  for (XlaOp operand : operands) {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    operand_shapes.push_back(*shape);
  }
  return operand_shapes;
}

std::string XlaBuilder::OpToString(XlaOp op) const {
  std::string s;
  ToStringHelper(&s, /*ident=*/0, op.handle());
  return s;
}

static std::string ShapeToString(const xla::ShapeProto& shape) {
  if (shape.tuple_shapes_size() > 1) {
    return absl::StrCat(
        "(",
        absl::StrJoin(shape.tuple_shapes(), ", ",
                      [&](std::string* s, const xla::ShapeProto& subshape) {
                        absl::StrAppend(s, ShapeToString(subshape));
                      }),
        ")");
  }
  return absl::StrCat("[", absl::StrJoin(shape.dimensions(), ", "), "]");
}

void XlaBuilder::ToStringHelper(std::string* out, int ident,
                                int64_t op_handle) const {
  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  absl::StrAppend(out, std::string(ident, ' '), instr.opcode(),
                  ", shape=", ShapeToString(instr.shape()));
  if (instr.has_metadata()) {
    absl::StrAppend(out, ", metadata={", instr.metadata().source_file(), ":",
                    instr.metadata().source_line(), "}");
  }
  if (instr.operand_ids_size()) {
    absl::StrAppend(out, "\n");
  }
  absl::StrAppend(out, absl::StrJoin(instr.operand_ids(), "\n",
                                     [&](std::string* s, int64_t subop) {
                                       ToStringHelper(s, ident + 2, subop);
                                     }));
}

XlaBuilder::XlaBuilder(const string& computation_name)
    : name_(computation_name) {}

XlaBuilder::~XlaBuilder() {}

XlaOp XlaBuilder::ReportError(const Status& error) {
  CHECK(!error.ok());
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
  return XlaOp(this);
}

XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
  if (!first_error_.ok()) {
    return XlaOp(this);
  }
  if (!op.ok()) {
    return ReportError(op.status());
  }
  return op.ValueOrDie();
}

XlaOp XlaBuilder::ReportErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  return ReportErrorOrReturn(op_creator());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64_t root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
                      LookUpInstructionByHandle(root_id));

  ProgramShape program_shape;

  *program_shape.mutable_result() = Shape(root_proto->shape());

  // Check that the parameter numbers are contiguous from 0, and add parameter
  // shapes and names to the program shape.
  const int64_t param_count = parameter_numbers_.size();
  for (int64_t i = 0; i < param_count; i++) {
    program_shape.add_parameters();
    program_shape.add_parameter_names();
  }
  for (const HloInstructionProto& instr : instructions_) {
    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
    // to verify contiguity, we just need to verify that every parameter is in
    // the right range.
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      const int64_t index = instr.parameter_number();
      TF_RET_CHECK(index >= 0 && index < param_count)
          << "invalid parameter number: " << index;
      *program_shape.mutable_parameters(index) = Shape(instr.shape());
      *program_shape.mutable_parameter_names(index) = instr.name();
    }
  }
  return program_shape;
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  TF_RET_CHECK(!instructions_.empty());
  return GetProgramShape(instructions_.back().id());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return GetProgramShape(root.handle());
}

void XlaBuilder::IsConstantVisitor(const int64_t op_handle, int depth,
                                   absl::flat_hash_set<int64>* visited,
                                   bool* is_constant) const {
  if (visited->contains(op_handle) || !*is_constant) {
    return;
  }

  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  HloInstructionProto to_print(instr);
  to_print.clear_shape();
  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
  const string indent =
      absl::StrJoin(std::vector<absl::string_view>(depth, "  "), "");
  if (VLOG_IS_ON(2)) {
    VLOG(2) << indent << "Visiting:";
    for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
      VLOG(2) << indent << l;
    }
  }
  switch (opcode) {
    default:
      for (const int64_t operand_id : instr.operand_ids()) {
        IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
      }
      // TODO(b/32495713): We aren't checking the called computations.
      break;

    case HloOpcode::kGetDimensionSize:
      // GetDimensionSize is always considered constant in XLA -- if a dynamic
      // dimension is present, -1 is returned.
      break;
    // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant. We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
    case HloOpcode::kCustomCall:
      if (instr.custom_call_target() == "SetBound") {
        // SetBound is considered constant -- the bound is used as the value.
        break;
      }
      TF_FALLTHROUGH_INTENDED;
    case HloOpcode::kWhile:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kScatter:
      // TODO(b/32495713): We aren't checking the embedded computation in
      // Scatter.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter:
      *is_constant = false;
      break;
    case HloOpcode::kGetTupleElement: {
      const HloInstructionProto& operand_instr =
          *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
      if (HloOpcodeString(HloOpcode::kTuple) == operand_instr.opcode()) {
        IsConstantVisitor(operand_instr.operand_ids(instr.tuple_index()),
                          depth + 1, visited, is_constant);
      } else {
        for (const int64_t operand_id : instr.operand_ids()) {
          IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
        }
      }
    }
  }
  if (VLOG_IS_ON(1) && !*is_constant) {
    VLOG(1) << indent << "Non-constant: ";
    for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
      VLOG(1) << indent << l;
    }
  }
  visited->insert(op_handle);
}

Status XlaBuilder::SetDynamicBinding(int64_t dynamic_size_param_num,
                                     ShapeIndex dynamic_size_param_index,
                                     int64_t target_param_num,
                                     ShapeIndex target_param_index,
                                     int64_t target_dim_num) {
  bool param_exists = false;
  for (size_t index = 0; index < instructions_.size(); ++index) {
    HloInstructionProto& instr = instructions_[index];
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
        instr.parameter_number() == target_param_num) {
      param_exists = true;
      Shape param_shape(instr.shape());
      Shape* param_shape_ptr = &param_shape;
      for (int64_t index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      *instr.mutable_shape() = param_shape.ToProto();
      instruction_shapes_[index] =
          absl::make_unique<Shape>(std::move(param_shape));
    }
  }
  if (!param_exists) {
    return InvalidArgument(
        "Asked to mark parameter %lld as dynamic sized parameter, but the "
        "parameter doesn't exist",
        target_param_num);
  }

  TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
      DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
                                                dynamic_size_param_index},
      DynamicParameterBinding::DynamicDimension{
          target_param_num, target_param_index, target_dim_num}));
  return Status::OK();
}
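
// Illustrative sketch (hypothetical shapes): binding parameter 0's scalar
// value as the dynamic size of dimension 1 of parameter 1:
//   TF_RETURN_IF_ERROR(builder.SetDynamicBinding(
//       /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{},
//       /*target_param_num=*/1, /*target_param_index=*/{},
//       /*target_dim_num=*/1));
// After this call, parameter 1's shape reports dimension 1 as dynamic.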

Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
                                                   std::string attribute,
                                                   std::string value) {
  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
  (*frontend_attributes->mutable_map())[attribute] = std::move(value);
  return Status::OK();
}

XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->ReportError(
        AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
    return {};
  }
  return build_status.ConsumeValueOrDie();
}

Status XlaBuilder::GetCurrentStatus() const {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }
  return Status::OK();
}

StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());
  return Build(instructions_.back().id(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
                                           bool remove_dynamic_dimensions) {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return Build(root.handle(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(int64_t root_id,
                                           bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());

  // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet;
  // remove all dynamic dimensions before building the XLA program, until we
  // have support in the backend.
  if (remove_dynamic_dimensions) {
    std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
      if (shape->tuple_shapes_size() != 0) {
        for (int64_t i = 0; i < shape->tuple_shapes_size(); ++i) {
          remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
        }
      }
      for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
        shape->set_dynamic_dimension(i, false);
      }
    };
    for (size_t index = 0; index < instructions_.size(); ++index) {
      remove_dynamic_dimension(instruction_shapes_[index].get());
      *instructions_[index].mutable_shape() =
          instruction_shapes_[index]->ToProto();
    }
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
  *entry.mutable_program_shape() = program_shape.ToProto();
  entry.set_root_id(root_id);

  for (auto& instruction : instructions_) {
    // Ensures that the instruction names are unique among the whole graph.
    instruction.set_name(
        GetFullName(instruction.name(), kNameSeparator, instruction.id()));
    entry.add_instructions()->Swap(&instruction);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = entry.program_shape();
  for (auto& e : embedded_) {
    module->add_computations()->Swap(&e.second);
  }
  module->add_computations()->Swap(&entry);
  if (!input_output_aliases_.empty()) {
    TF_RETURN_IF_ERROR(
        PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
  }
  *(module->mutable_dynamic_parameter_binding()) =
      dynamic_parameter_binding_.ToProto();

  // Clear data held by this builder.
  this->instructions_.clear();
  this->instruction_shapes_.clear();
  this->handle_to_index_.clear();
  this->embedded_.clear();
  this->parameter_numbers_.clear();

  return std::move(computation);
}
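
// Minimal end-to-end sketch of the builder lifecycle (illustrative client
// code, assuming the free-function wrappers declared in xla_builder.h):
//   XlaBuilder b("add_one");
//   auto p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
//   Add(p, ConstantR0<float>(&b, 1.0f));
//   StatusOr<XlaComputation> computation = b.Build();  // root = last op
// Note that Build() clears the builder's state (instructions, shapes,
// parameter numbers), so a builder cannot be reused after a successful
// Build().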

/* static */ Status XlaBuilder::PopulateInputOutputAlias(
    HloModuleProto* module, const ProgramShape& program_shape,
    const std::vector<InputOutputAlias>& input_output_aliases) {
  HloInputOutputAliasConfig config(program_shape.result());
  for (auto& alias : input_output_aliases) {
    // The HloInputOutputAliasConfig does not do parameter validation, as it
    // only carries the result shape. Maybe it should be constructed with a
    // ProgramShape to allow full validation. We will still get an error when
    // trying to compile the HLO module, but it would be better to have
    // validation at this stage.
    if (alias.param_number >= program_shape.parameters_size()) {
      return InvalidArgument("Invalid parameter number %ld (total %ld)",
                             alias.param_number,
                             program_shape.parameters_size());
    }
    const Shape& parameter_shape = program_shape.parameters(alias.param_number);
    if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
      return InvalidArgument("Invalid parameter %ld index: %s",
                             alias.param_number,
                             alias.param_index.ToString().c_str());
    }
    TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
                                         alias.param_index, alias.kind));
  }
  *module->mutable_input_output_alias() = config.ToProto();
  return Status::OK();
}

StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }

  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                 XlaOp operand) {
  TF_RETURN_IF_ERROR(first_error_);

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

  CHECK(ShapeUtil::IsScalar(*operand_shape) ||
        operand_shape->rank() == output_shape.rank());
  Shape broadcast_shape =
      ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());

  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(*operand_shape)) {
    return InDimBroadcast(broadcast_shape, operand, {});
  }

  // Do explicit broadcast for degenerate broadcast.
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  for (int i = 0; i < operand_shape->rank(); i++) {
    if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_shape->dimensions(i));
    } else {
      TF_RET_CHECK(operand_shape->dimensions(i) == 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand shape: "
          << *operand_shape << "; output_shape: " << output_shape;
    }
  }

  Shape reshaped_shape =
      ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);

  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);

  for (auto& unmodified : unmodified_dims) {
    if (operand_shape->is_dynamic_dimension(unmodified.first)) {
      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
    }
  }

  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(
      XlaOp reshaped_operand,
      ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
}
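
// Example of the degenerate-broadcast path above (illustrative shapes): an
// f32[2,1] operand broadcast to f32[2,3] is first reshaped to f32[2],
// dropping the size-1 dimension, and then InDimBroadcast with dimensions={0}
// maps operand dimension 0 onto output dimension 0 of f32[2,3].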

XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
    return AddOpWithShape(unop, shape, {operand});
  });
}

XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions,
                           absl::optional<ComparisonDirection> direction,
                           absl::optional<Comparison::Type> type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBinaryOpShape(
                         binop, *lhs_shape, *rhs_shape, broadcast_dimensions));

    const int64_t lhs_rank = lhs_shape->rank();
    const int64_t rhs_rank = rhs_shape->rank();

    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
      const bool should_broadcast_lhs = lhs_rank < rhs_rank;
      XlaOp from = should_broadcast_lhs ? lhs : rhs;
      const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;

      std::vector<int64> to_size;
      std::vector<bool> to_size_is_dynamic;
      for (int i = 0; i < shape.rank(); i++) {
        to_size.push_back(shape.dimensions(i));
        to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
      }
      for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
        int64_t to_dim = broadcast_dimensions[from_dim];
        to_size[to_dim] = from_shape.dimensions(from_dim);
        to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
      }

      const Shape& broadcasted_shape = ShapeUtil::MakeShape(
          from_shape.element_type(), to_size, to_size_is_dynamic);
      TF_ASSIGN_OR_RETURN(
          XlaOp broadcasted_operand,
          InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));

      updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
      updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
    }

    TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
                        GetShapePtr(updated_lhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_lhs,
                          AddBroadcastSequence(shape, updated_lhs));
    }
    TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
                        GetShapePtr(updated_rhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_rhs,
                          AddBroadcastSequence(shape, updated_rhs));
    }

    if (binop == HloOpcode::kCompare) {
      if (!direction.has_value()) {
        return InvalidArgument(
            "kCompare expects a ComparisonDirection, but none was provided.");
      }
      if (type == absl::nullopt) {
        return Compare(shape, updated_lhs, updated_rhs, *direction);
      } else {
        return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
      }
    }

    if (direction.has_value()) {
      return InvalidArgument(
          "A comparison direction is provided for a non-compare opcode: %s.",
          HloOpcodeString(binop));
    }
    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
  });
}
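
// Example of the rank-mismatch path above (illustrative): adding an f32[3]
// vector to an f32[2,3] matrix with broadcast_dimensions={1} first
// InDimBroadcasts the vector to f32[2,3] (operand dim 0 -> output dim 1) and
// then applies the element-wise add on equal shapes:
//   Add(matrix, vector, /*broadcast_dimensions=*/{1});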

XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                      XlaOp lhs, XlaOp rhs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), binop, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction) {
  TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
  return Compare(
      shape, lhs, rhs, direction,
      Comparison::DefaultComparisonType(operand_shape.element_type()));
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction,
                                    Comparison::Type type) {
  HloInstructionProto instr;
  instr.set_comparison_direction(ComparisonDirectionToString(direction));
  instr.set_comparison_type(ComparisonTypeToString(type));
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
}

XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    XlaOp updated_ehs = ehs;
    // The client API supports implicit broadcast for kSelect and kClamp, but
    // XLA does not support implicit broadcast. Make implicit broadcast
    // explicit and update the operands.
    if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
      TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
      TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
      TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));

      absl::optional<Shape> non_scalar_shape;
      for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
        if (shape->IsArray() && shape->rank() != 0) {
          if (non_scalar_shape.has_value()) {
            // TODO(jpienaar): The case where we need to compute the broadcasted
            // shape by considering multiple of the shapes is not implemented.
            // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
            TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
                         shape->dimensions())
                << "Unimplemented implicit broadcast.";
          } else {
            non_scalar_shape = *shape;
          }
        }
      }
      if (non_scalar_shape.has_value()) {
        if (ShapeUtil::IsScalar(*lhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_lhs,
                              AddBroadcastSequence(*non_scalar_shape, lhs));
        }
        if (ShapeUtil::IsScalar(*rhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_rhs,
                              AddBroadcastSequence(*non_scalar_shape, rhs));
        }
        if (ShapeUtil::IsScalar(*ehs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_ehs,
                              AddBroadcastSequence(*non_scalar_shape, ehs));
        }
      }
    }

    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
    TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
    StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
        triop, *lhs_shape, *rhs_shape, *ehs_shape);
    if (!status_or_shape.status().ok()) {
      return InvalidArgument(
          "%s Input scalar shapes may have been changed to non-scalar shapes.",
          status_or_shape.status().error_message());
    }

    return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
                          {updated_lhs, updated_rhs, updated_ehs});
  });
}

XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (literal.shape().IsArray() && literal.element_count() > 1 &&
        literal.IsAllFirst()) {
      Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
      HloInstructionProto instr;
      *instr.mutable_shape() = scalar.shape().ToProto();
      *instr.mutable_literal() = scalar.ToProto();
      TF_ASSIGN_OR_RETURN(
          XlaOp scalar_op,
          AddInstruction(std::move(instr), HloOpcode::kConstant));
      return Broadcast(scalar_op, literal.shape().dimensions());
    } else {
      HloInstructionProto instr;
      *instr.mutable_shape() = literal.shape().ToProto();
      *instr.mutable_literal() = literal.ToProto();
      return AddInstruction(std::move(instr), HloOpcode::kConstant);
    }
  });
}
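
// Note on the splat path above: an array literal whose elements all equal its
// first element (e.g. an f32[1024] filled with 0.5f) is emitted as a scalar
// constant plus a Broadcast rather than as a 1024-element literal, which
// keeps the serialized HloModuleProto small.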

XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(iota_dimension);
    return AddInstruction(std::move(instr), HloOpcode::kIota);
  });
}

XlaOp XlaBuilder::Iota(PrimitiveType type, int64_t size) {
  return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}
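
// Example (illustrative): Iota(&b, S32, 5) yields the s32[5] vector
// [0, 1, 2, 3, 4]. The shape overload generalizes this to any dimension,
// e.g. Iota(&b, ShapeUtil::MakeShape(S32, {2, 3}), /*iota_dimension=*/1)
// yields [[0, 1, 2], [0, 1, 2]].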

XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       absl::Span<const XlaOp> operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
                                         operand_shape_ptrs,
                                         /*to_apply=*/called_program_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
  });
}

XlaOp XlaBuilder::Parameter(
    int64_t parameter_number, const Shape& shape, const string& name,
    const std::vector<bool>& replicated_at_leaf_buffers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!parameter_numbers_.insert(parameter_number).second) {
      return InvalidArgument("parameter %d already registered",
                             parameter_number);
    }
    instr.set_parameter_number(parameter_number);
    instr.set_name(name);
    *instr.mutable_shape() = shape.ToProto();
    if (!replicated_at_leaf_buffers.empty()) {
      auto replication = instr.mutable_parameter_replication();
      for (bool replicated : replicated_at_leaf_buffers) {
        replication->add_replicated_at_leaf_buffers(replicated);
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kParameter);
  });
}

XlaOp XlaBuilder::Broadcast(XlaOp operand,
                            absl::Span<const int64> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));

    // The client-level broadcast op just appends dimensions on the left (adds
    // lowest numbered dimensions). The HLO broadcast instruction is more
    // flexible and can add new dimensions anywhere. The instruction's
    // dimensions field maps operand dimensions to dimensions in the broadcast
    // output, so to append dimensions on the left the instruction's dimensions
    // should just be the n highest dimension numbers of the output shape where
    // n is the number of input dimensions.
    const int64_t operand_rank = operand_shape->rank();
    std::vector<int64> dimensions(operand_rank);
    for (int i = 0; i < operand_rank; ++i) {
      dimensions[i] = i + shape.rank() - operand_rank;
    }
    return InDimBroadcast(shape, operand, dimensions);
  });
}
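
// Example of the mapping above (illustrative): broadcasting an f32[2,3]
// operand with broadcast_sizes={4} produces f32[4,2,3]; the instruction's
// dimensions field is {1, 2}, mapping operand dims 0 and 1 onto the two
// highest output dims.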

XlaOp XlaBuilder::BroadcastInDim(
    XlaOp operand, const absl::Span<const int64> out_dim_size,
    const absl::Span<const int64> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Output shape: in the case of a degenerate broadcast, out_dim_size is
    // not necessarily the same as the dimension sizes of the output shape.
    TF_ASSIGN_OR_RETURN(auto output_shape,
                        ShapeUtil::MakeValidatedShape(
                            operand_shape->element_type(), out_dim_size));
    int64_t broadcast_rank = broadcast_dimensions.size();
    if (operand_shape->rank() != broadcast_rank) {
      return InvalidArgument(
          "Size of broadcast_dimensions has to match operand's rank; operand "
          "rank: %lld, size of broadcast_dimensions %u.",
          operand_shape->rank(), broadcast_dimensions.size());
    }
    for (int i = 0; i < broadcast_rank; i++) {
      const int64_t num_dims = out_dim_size.size();
      if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > num_dims) {
        return InvalidArgument("Broadcast dimension %lld is out of bound",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(
          broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           *operand_shape, output_shape, broadcast_dimensions)
                           .status());
    std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_rank; i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If broadcast is not degenerate, return broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}
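
// Example (illustrative): BroadcastInDim(x, /*out_dim_size=*/{2, 3},
// /*broadcast_dimensions=*/{1}) with x of shape f32[3] places x's only
// dimension at output dimension 1, yielding f32[2,3]. With x of shape f32[1],
// the same call takes the degenerate path and still yields f32[2,3].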

StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                            int64_t inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  if (inferred_dimension != -1) {
    instr.add_dimensions(inferred_dimension);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
}

XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
                        absl::Span<const int64> limit_indices,
                        absl::Span<const int64> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
                                         *operand_shape, start_indices,
                                         limit_indices, strides));
    return SliceInternal(shape, operand, start_indices, limit_indices, strides);
  });
}

StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
                                          absl::Span<const int64> start_indices,
                                          absl::Span<const int64> limit_indices,
                                          absl::Span<const int64> strides) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int i = 0, end = start_indices.size(); i < end; i++) {
    auto* slice_config = instr.add_slice_dimensions();
    slice_config->set_start(start_indices[i]);
    slice_config->set_limit(limit_indices[i]);
    slice_config->set_stride(strides[i]);
  }
  return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}

XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index,
                             int64_t limit_index, int64_t stride,
                             int64_t dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> starts(shape->rank(), 0);
    std::vector<int64> limits(shape->dimensions().begin(),
                              shape->dimensions().end());
    std::vector<int64> strides(shape->rank(), 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}
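
// Example (illustrative): SliceInDim(x, /*start_index=*/1, /*limit_index=*/5,
// /*stride=*/2, /*dimno=*/0) on an f32[6] operand keeps elements 1 and 3,
// producing f32[2]; all other dimensions are passed through unchanged.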

XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            *operand_shape, start_indices_shapes, slice_sizes));
    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64> slice_sizes) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  for (int64_t size : slice_sizes) {
    instr.add_dynamic_slice_sizes(size);
  }

  std::vector<XlaOp> operands = {operand};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
}

XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                         *operand_shape, *update_shape, start_indices_shapes));
    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  std::vector<XlaOp> operands = {operand, update};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                        operands);
}

XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                              int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
                                         operand_shape_ptrs, dimension));
    return ConcatInDimInternal(shape, operands, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  instr.add_dimensions(dimension);

  return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}

XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
                      const PaddingConfig& padding_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
                        GetShapePtr(padding_value));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferPadShape(
                         *operand_shape, *padding_value_shape, padding_config));
    return PadInternal(shape, operand, padding_value, padding_config);
  });
}

XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                           int64_t pad_lo, int64_t pad_hi) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    PaddingConfig padding_config = MakeNoPaddingConfig(shape->rank());
    auto* dims = padding_config.mutable_dimensions(dimno);
    dims->set_edge_padding_low(pad_lo);
    dims->set_edge_padding_high(pad_hi);
    return Pad(operand, padding_value, padding_config);
  });
}
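
// Example (illustrative): PadInDim(x, zero, /*dimno=*/1, /*pad_lo=*/1,
// /*pad_hi=*/2) on an f32[4,5] operand pads only dimension 1, producing
// f32[4,8]; interior padding stays zero because MakeNoPaddingConfig is used
// as the starting configuration.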

StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
                                        XlaOp padding_value,
                                        const PaddingConfig& padding_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_padding_config() = padding_config;
  return AddInstruction(std::move(instr), HloOpcode::kPad,
                        {operand, padding_value});
}

XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                          absl::Span<const int64> new_sizes,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
                                               *operand_shape, dimensions,
                                               new_sizes, inferred_dimension));
    XlaOp transposed = IsIdentityPermutation(dimensions)
                           ? operand
                           : Transpose(operand, dimensions);
    return ReshapeInternal(shape, transposed, inferred_dimension);
  });
}

XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> dimensions(shape->dimensions_size());
    std::iota(dimensions.begin(), dimensions.end(), 0);
    return Reshape(operand, dimensions, new_sizes, inferred_dimension);
  });
}

XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    return ReshapeInternal(shape, operand, inferred_dimension);
  });
}

XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
                                 absl::Span<const XlaOp> dim_sizes,
                                 absl::Span<const int64> new_size_bounds,
                                 const std::vector<bool>& dims_are_dynamic) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> dim_size_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
                        GetOperandShapes(dim_sizes));

    absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferDynamicReshapeShape(
                            *operand_shape, dim_size_shape_ptrs,
                            new_size_bounds, dims_are_dynamic));
    TF_RETURN_IF_ERROR(first_error_);
    std::vector<XlaOp> operands;
    operands.reserve(1 + dim_sizes.size());
    operands.push_back(operand);
    for (const XlaOp& dim_size : dim_sizes) {
      operands.push_back(dim_size);
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
                          operands);
  });
}

XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Not collapsing anything; we can trivially return the operand rather
      // than enqueueing a no-op reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Checks that the collapsed dimensions are in order and consecutive.
    for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
    TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));

    VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
    VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");

    std::vector<int64> new_sizes;
    for (int i = 0; i < original_shape->rank(); ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape->dimensions(i));
      } else {
        new_sizes.back() *= original_shape->dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";

    return Reshape(operand, new_sizes);
  });
}
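
// Example (illustrative): collapsing dimensions {1, 2} of an f32[2,3,4]
// operand multiplies the collapsed sizes together and reshapes to f32[2,12];
// Collapse(x, {1}) or Collapse(x, {}) returns the operand unchanged.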

void XlaBuilder::Trace(const string& tag, XlaOp operand) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
  });
}

XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
    TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
    TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
    HloOpcode opcode =
        true_shape->IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
    return TernaryOp(opcode, pred, on_true, on_false);
  });
}

XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferVariadicOpShape(
                            HloOpcode::kTuple, operand_shape_ptrs));
    return TupleInternal(shape, elements);
  });
}

StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
                                          absl::Span<const XlaOp> elements) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
}

XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64_t index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
    if (!tuple_shape->IsTuple()) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(*tuple_shape));
    }
    if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
      return InvalidArgument(
          "GetTupleElement() index (%d) out of range for tuple shape %s", index,
          ShapeUtil::HumanString(*tuple_shape));
    }
    return GetTupleElementInternal(
        ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
        index);
  });
}

StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
                                                    XlaOp tuple_data,
                                                    int64_t index) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_tuple_index(index);
  return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
                        {tuple_data});
}

XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
                      const PrecisionConfig* precision_config,
                      absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));

    DotDimensionNumbers dimension_numbers;
    dimension_numbers.add_lhs_contracting_dimensions(
        lhs_shape->dimensions_size() == 1 ? 0 : 1);
    dimension_numbers.add_rhs_contracting_dimensions(0);
    return DotGeneral(lhs, rhs, dimension_numbers, precision_config,
                      preferred_element_type);
  });
}
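
// Example of the defaults above (illustrative): for lhs f32[2,3] and rhs
// f32[3,4], Dot contracts lhs dimension 1 against rhs dimension 0, producing
// f32[2,4]; for a vector lhs f32[3] it contracts lhs dimension 0 instead.
// Anything more general (batch dimensions, other contracting dimensions) goes
// through DotGeneral with explicit DotDimensionNumbers.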

XlaOp XlaBuilder::DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferDotOpShape(
            *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
    return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
                              precision_config);
  });
}
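
// Example (usage sketch): a batched matmul f32[b, m, k] x f32[b, k, n] can be
// expressed by marking dim 0 on both sides as a batch dimension and
// contracting lhs dim 2 with rhs dim 1:
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   XlaOp out = DotGeneral(lhs, rhs, dnums);  // f32[b, m, n]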

StatusOr<XlaOp> XlaBuilder::DotGeneralInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_dot_dimension_numbers() = dimension_numbers;
  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
}

Status XlaBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) const {
  if (lhs_shape.rank() != rhs_shape.rank()) {
    return InvalidArgument(
        "Convolution arguments must have same number of "
        "dimensions. Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_dims = lhs_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          return InvalidArgument("Expected %d elements for %s, but got %d.",
                                 num_spatial_dims, field_name, numbers.size());
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
                                   field_name, i, numbers.Get(i));
          }
        }
        return Status::OK();
      };
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("input_spatial_dimensions",
                               dimension_numbers.input_spatial_dimensions()));
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("kernel_spatial_dimensions",
                               dimension_numbers.kernel_spatial_dimensions()));
  return check_spatial_dimensions(
      "output_spatial_dimensions",
      dimension_numbers.output_spatial_dimensions());
}

XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> window_strides, Padding padding,
                       int64_t feature_group_count, int64_t batch_group_count,
                       const PrecisionConfig* precision_config,
                       absl::optional<PrimitiveType> preferred_element_type) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()),
      feature_group_count, batch_group_count, precision_config,
      preferred_element_type);
}

XlaOp XlaBuilder::ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()),
                     feature_group_count, batch_group_count, precision_config,
                     preferred_element_type);
}

XlaOp XlaBuilder::ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));

    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> base_area_dimensions(
        dimension_numbers.input_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
         ++i) {
      base_area_dimensions[i] =
          lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
    }

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    return ConvGeneral(lhs, rhs, window_strides,
                       MakePadding(base_area_dimensions, window_dimensions,
                                   window_strides, padding),
                       dimension_numbers, feature_group_count,
                       batch_group_count, precision_config,
                       preferred_element_type);
  });
}

XlaOp XlaBuilder::ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers, feature_group_count,
                            batch_group_count, precision_config,
                            preferred_element_type);
}

XlaOp XlaBuilder::ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    TF_ASSIGN_OR_RETURN(Window window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding,
                            lhs_dilation, rhs_dilation));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferConvolveShape(
            *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
            window, dimension_numbers, preferred_element_type));
    return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
                                      padding, lhs_dilation, rhs_dilation,
                                      dimension_numbers, feature_group_count,
                                      batch_group_count, precision_config);
  });
}
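
// Example (usage sketch; assumes the layouts implied by
// CreateDefaultConvDimensionNumbers, i.e. an f32[N, C, H, W] input against an
// f32[O, I, KH, KW] kernel): a stride-1 2-D convolution with SAME padding via
// the free-function wrapper:
//   XlaOp out = Conv(input, kernel, /*window_strides=*/{1, 1},
//                    Padding::kSame);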

StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
  TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
       ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
                                         window_dimensions, window_strides,
                                         padding, lhs_dilation, rhs_dilation));
  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferConvolveShape(
          *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);
  instr.set_padding_type(padding_type);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return std::move(instr);
}
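
// DynamicConvInstruction only assembles the convolution proto; the three
// callers below turn it into a custom call by setting the call target to
// "DynamicConvolutionInputGrad", "DynamicConvolutionKernelGrad", or
// "DynamicConvolutionForward" respectively.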

XlaOp XlaBuilder::DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionInputGrad");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {input_sizes, lhs, rhs});
  });
}

XlaOp XlaBuilder::DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(activations, gradients, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers,
                               feature_group_count, batch_group_count,
                               precision_config, padding_type,
                               preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionKernelGrad");
    // The kernel gradient has the kernel's shape and shouldn't have any
    // dynamic sizes.
    instr.mutable_shape()->clear_is_dynamic_dimension();
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {activations, gradients});
  });
}

XlaOp XlaBuilder::DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));
    instr.set_custom_call_target("DynamicConvolutionForward");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }

  return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
}

XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                      const absl::Span<const int64> fft_length) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
                                         *operand_shape, fft_type, fft_length));
    return FftInternal(shape, operand, fft_type, fft_length);
  });
}

StatusOr<XlaOp> XlaBuilder::FftInternal(
    const Shape& shape, XlaOp operand, const FftType fft_type,
    const absl::Span<const int64> fft_length) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_fft_type(fft_type);
  for (int64_t i : fft_length) {
    instr.add_fft_length(i);
  }

  return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
}

StatusOr<XlaOp> XlaBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  HloInstructionProto instr;
  *instr.mutable_triangular_solve_options() = std::move(options);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b});
}

StatusOr<XlaOp> XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
                                             bool lower) {
  HloInstructionProto instr;
  xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
  options.set_lower(lower);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
}

XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
    instr.set_infeed_config(config);

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    // Infeed takes a single token operand. Generate the token to pass to the
    // infeed.
    XlaOp token;
    auto make_token = [&]() {
      HloInstructionProto token_instr;
      *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
      return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
    };
    if (sharding()) {
      // Arbitrarily assign the token to device 0.
      OpSharding sharding = sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this, sharding);
      TF_ASSIGN_OR_RETURN(token, make_token());
    } else {
      TF_ASSIGN_OR_RETURN(token, make_token());
    }

    // The sharding is set by the client according to the data tuple shape.
    // However, the shape of the infeed instruction is a tuple containing the
    // data and a token. For tuple sharding type, the sharding must be changed
    // to accommodate the token.
    XlaOp infeed;
    if (sharding() && sharding()->type() == OpSharding::TUPLE) {
      // TODO(b/80000000): Remove this when clients have been updated to handle
      // tokens.
      OpSharding infeed_instruction_sharding = *sharding();
      // Arbitrarily assign the token to device 0.
      *infeed_instruction_sharding.add_tuple_shardings() =
          sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this,
                                                  infeed_instruction_sharding);
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    } else {
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    }

    // The infeed instruction produces a tuple of the transferred data and a
    // token. Return the XLA op containing just the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto infeed_data;
    *infeed_data.mutable_shape() = shape.ToProto();
    infeed_data.set_tuple_index(0);
    return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
                          {infeed});
  });
}

XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
                                  const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }
    return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
  });
}

StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = infeed_instruction_shape.ToProto();
  instr.set_infeed_config(config);
  return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
}
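
// Example (usage sketch): the token-threaded form makes transfer ordering
// explicit; the result is a (data, token) tuple:
//   XlaOp token = CreateToken(&b);
//   XlaOp in = InfeedWithToken(token, shape, /*config=*/"");
//   XlaOp data = GetTupleElement(in, 0);       // the transferred value
//   XlaOp next_token = GetTupleElement(in, 1); // orders the next transfer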

void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
                         const string& outfeed_config) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();

    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();

    instr.set_outfeed_config(outfeed_config);

    // Outfeed takes a token as its second operand. Generate the token to pass
    // to the outfeed.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    TF_RETURN_IF_ERROR(
        AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
            .status());

    // The outfeed instruction produces a token. However, existing users expect
    // a nil shape (empty tuple). This should only be relevant if the outfeed
    // is the root of a computation.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto tuple_instr;
    *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();

    // The dummy tuple should have no sharding.
    {
      XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
      TF_ASSIGN_OR_RETURN(
          XlaOp empty_tuple,
          AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
      return empty_tuple;
    }
  });
}

XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
                                   const Shape& shape_with_layout,
                                   const string& outfeed_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    return OutfeedWithTokenInternal(operand, token, shape_with_layout,
                                    outfeed_config);
  });
}

StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
    const string& outfeed_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
  *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
  instr.set_outfeed_config(outfeed_config);
  return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
                        {operand, token});
}

XlaOp XlaBuilder::CreateToken() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
  });
}

XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (tokens.empty()) {
      return InvalidArgument("AfterAll requires at least one operand");
    }
    for (int i = 0, end = tokens.size(); i < end; ++i) {
      XlaOp operand = tokens[i];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      if (!operand_shape->IsToken()) {
        return InvalidArgument(
            "All operands to AfterAll must be tokens; operand %d has shape %s",
            i, ShapeUtil::HumanString(*operand_shape));
      }
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
  });
}
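
// Example (usage sketch): joining tokens from several side-effecting ops so
// that a subsequent transfer is ordered after all of them:
//   XlaOp joined = AfterAll(&b, {token_a, token_b});
//   XlaOp out = OutfeedWithToken(result, joined, shape_with_layout,
//                                /*outfeed_config=*/"");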

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, absl::optional<Window> window,
    absl::optional<ConvolutionDimensionNumbers> dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with "
          "'$' are reserved for internal use.",
          call_target_name);
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom "
            "call with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      int64_t operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        ++operand_num;
      }
    }
    return CustomCallInternal(call_target_name, operands, shape, opaque,
                              operand_shapes_with_layout, has_side_effect,
                              output_operand_aliasing, literal, window, dnums,
                              schedule, api_version);
  });
}

StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, absl::optional<Window> window,
    absl::optional<ConvolutionDimensionNumbers> dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_custom_call_target(call_target_name);
  instr.set_backend_config(opaque);
  if (operand_shapes_with_layout.has_value()) {
    instr.set_constrain_layout(true);
    for (const Shape& operand_shape : *operand_shapes_with_layout) {
      *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
    }
  }
  if (literal != nullptr) {
    *instr.mutable_literal() = literal->ToProto();
  }
  instr.set_custom_call_has_side_effect(has_side_effect);
  for (const auto& pair : output_operand_aliasing) {
    auto aliasing = instr.add_custom_call_output_operand_aliasing();
    aliasing->set_operand_index(pair.second.first);
    for (int64_t index : pair.second.second) {
      aliasing->add_operand_shape_index(index);
    }
    for (int64_t index : pair.first) {
      aliasing->add_output_shape_index(index);
    }
  }
  if (window.has_value()) {
    *instr.mutable_window() = *window;
  }
  if (dnums.has_value()) {
    *instr.mutable_convolution_dimension_numbers() = *dnums;
  }
  instr.set_custom_call_schedule(schedule);
  instr.set_custom_call_api_version(api_version);
  return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const XlaComputation& computation, const Shape& shape,
    const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with "
          "'$' are reserved for internal use.",
          call_target_name);
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.set_custom_call_target(call_target_name);
    instr.set_backend_config(opaque);
    if (literal != nullptr) {
      *instr.mutable_literal() = literal->ToProto();
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom "
            "call with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      instr.set_constrain_layout(true);
      int64_t operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
        ++operand_num;
      }
    }
    AddCalledComputation(computation, &instr);
    for (const auto& pair : output_operand_aliasing) {
      auto aliasing = instr.add_custom_call_output_operand_aliasing();
      aliasing->set_operand_index(pair.second.first);
      for (int64_t index : pair.second.second) {
        aliasing->add_operand_shape_index(index);
      }
      for (int64_t index : pair.first) {
        aliasing->add_output_shape_index(index);
      }
    }
    instr.set_custom_call_schedule(schedule);
    instr.set_custom_call_api_version(api_version);
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
  });
}
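
// Example (usage sketch; "my_target" names a hypothetical call target
// registered with the backend):
//   XlaOp out = CustomCall(&b, "my_target", /*operands=*/{x},
//                          /*shape=*/ShapeUtil::MakeShape(F32, {16}));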

XlaOp XlaBuilder::Transpose(XlaOp operand,
                            absl::Span<const int64> permutation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
                                         *operand_shape, permutation));
    return TransposeInternal(shape, operand, permutation);
  });
}

StatusOr<XlaOp> XlaBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : permutation) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
}
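
// Example: Transpose(operand, {1, 0}) swaps the two dimensions of a rank-2
// array, so an f32[2, 3] operand yields f32[3, 2].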

XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
                                         *operand_shape, dimensions));
    return RevInternal(shape, operand, dimensions);
  });
}

StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> dimensions) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : dimensions) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
}

XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       const XlaComputation& comparator, int64_t dimension,
                       bool is_stable) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                        GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
                                         HloOpcode::kSort, operand_shape_ptrs));
    return SortInternal(shape, operands, comparator, dimension, is_stable);
  });
}

StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
                                         absl::Span<const XlaOp> operands,
                                         const XlaComputation& comparator,
                                         int64_t dimension, bool is_stable) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_is_stable(is_stable);
  if (dimension == -1) {
    TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
    dimension = keys_shape->rank() - 1;
  }
  instr.add_dimensions(dimension);
  AddCalledComputation(comparator, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
}
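
// Example (usage sketch, assuming the comparator helpers in
// xla/client/lib/comparators.h): a stable sort along the last dimension (the
// default when dimension == -1):
//   XlaComputation less = CreateScalarLtComputation({F32}, &b);
//   XlaOp sorted = Sort({keys}, less, /*dimension=*/-1, /*is_stable=*/true);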

XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
  });
}

XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return BitcastConvertTypeInternal(shape, operand);
  });
}

StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
                                                       XlaOp operand) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
                        {operand});
}

XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
  return TernaryOp(HloOpcode::kClamp, min, operand, max);
}

XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions,
                      absl::Span<const XlaOp> static_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!static_operands.empty()) {
      return Unimplemented("static_operands is not supported in Map");
    }

    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferMapShape(
                         operand_shape_ptrs, called_program_shape, dimensions));
    *instr.mutable_shape() = shape.ToProto();

    Shape output_shape(instr.shape());
    const int64_t output_rank = output_shape.rank();
    AddCalledComputation(computation, &instr);
    std::vector<XlaOp> new_operands(operands.begin(), operands.end());
    for (XlaOp& new_operand : new_operands) {
      TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
      const int64_t rank = shape->rank();
      if (rank != output_rank) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            InDimBroadcast(output_shape, new_operand, {}));
        TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
      }
      if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            AddBroadcastSequence(output_shape, new_operand));
      }
    }

    return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
  });
}

XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
                        absl::Span<const XlaOp> parameters,
                        const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check the number of parameters per RNG distribution.
    switch (distribution) {
      case RandomDistribution::RNG_NORMAL:
      case RandomDistribution::RNG_UNIFORM:
        if (parameters.size() != 2) {
          return InvalidArgument(
              "RNG distribution (%s) expects 2 parameters, but got %ld",
              RandomDistribution_Name(distribution), parameters.size());
        }
        break;
      default:
        LOG(FATAL) << "unhandled distribution " << distribution;
    }

    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    return RngOpInternal(distribution, parameters, shape);
  });
}

StatusOr<XlaOp> XlaBuilder::RngOpInternal(RandomDistribution distribution,
                                          absl::Span<const XlaOp> parameters,
                                          const Shape& shape) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_distribution(distribution);

  return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
}

XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}

XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
                                  XlaOp initial_state, const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
    Shape output_shape = shape;
    switch (output_shape.element_type()) {
      case PrimitiveType::F32:
      case PrimitiveType::S32:
      case PrimitiveType::U32:
        output_shape.set_element_type(PrimitiveType::U32);
        break;
      case PrimitiveType::F64:
      case PrimitiveType::S64:
      case PrimitiveType::U64:
        output_shape.set_element_type(PrimitiveType::U64);
        break;
      default:
        return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
                               PrimitiveType_Name(output_shape.element_type()));
    }
    return RngBitGeneratorInternal(
        ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm,
        initial_state);
  });
}
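
// Example: requesting shape f32[8] from RngBitGenerator yields a
// (state, u32[8]) tuple, since 32-bit element types map to U32 raw bits and
// 64-bit ones to U64; converting the bits back to the requested element type
// is left to the caller.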

StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  HloInstructionProto instr;
  *instr.mutable_shape() = full_result_shape.ToProto();
  instr.set_rng_algorithm(algorithm);
  return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
                        {initial_state});
}

XlaOp XlaBuilder::While(const XlaComputation& condition,
                        const XlaComputation& body, XlaOp init) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Infer shape.
    TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
                        condition.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
                                         condition_program_shape,
                                         body_program_shape, *init_shape));
    return WhileInternal(shape, condition, body, init);
  });
}

StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
                                          const XlaComputation& condition,
                                          const XlaComputation& body,
                                          XlaOp init) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  // Body comes before condition computation in the vector.
  AddCalledComputation(body, &instr);
  AddCalledComputation(condition, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
}
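
// Example (usage sketch): a counting loop. "condition" maps s32[] -> pred[]
// and "body" maps s32[] -> s32[], each built in its own subsidiary builder:
//   XlaOp result = While(condition, body,
//                        /*init=*/ConstantR0<int32>(&b, 0));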

XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
                         const GatherDimensionNumbers& dimension_numbers,
                         absl::Span<const int64> slice_sizes,
                         bool indices_are_sorted) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
                        GetShapePtr(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
                                         *input_shape, *start_indices_shape,
                                         dimension_numbers, slice_sizes));
    return GatherInternal(shape, input, start_indices, dimension_numbers,
                          slice_sizes, indices_are_sorted);
  });
}

StatusOr<XlaOp> XlaBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  HloInstructionProto instr;
  instr.set_indices_are_sorted(indices_are_sorted);
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_gather_dimension_numbers() = dimension_numbers;
  for (int64_t bound : slice_sizes) {
    instr.add_gather_slice_sizes(bound);
  }

  return AddInstruction(std::move(instr), HloOpcode::kGather,
                        {input, start_indices});
}

XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers,
                          bool indices_are_sorted, bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
                        GetShapePtr(scatter_indices));
    TF_ASSIGN_OR_RETURN(const Shape* updates_shape, GetShapePtr(updates));
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        update_computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferScatterShape(
                         *input_shape, *scatter_indices_shape, *updates_shape,
                         to_apply_shape, dimension_numbers));
    return ScatterInternal(shape, input, scatter_indices, updates,
                           update_computation, dimension_numbers,
                           indices_are_sorted, unique_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::ScatterInternal(
    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
    const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_indices_are_sorted(indices_are_sorted);
    instr.set_unique_indices(unique_indices);
    *instr.mutable_shape() = shape.ToProto();
    *instr.mutable_scatter_dimension_numbers() = dimension_numbers;

    AddCalledComputation(update_computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kScatter,
                          {input, scatter_indices, updates});
  });
}

XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
                              const XlaComputation& true_computation,
                              XlaOp false_operand,
                              const XlaComputation& false_computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(predicate));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
      return InvalidArgument(
          "Argument to predicated-Conditional is not a scalar of PRED type "
          "(%s).",
          ShapeUtil::HumanString(*shape));
    }
    // The index of true_computation must be 0 and that of false_computation
    // must be 1.
    return ConditionalImpl(predicate, {&true_computation, &false_computation},
                           {true_operand, false_operand});
  });
}

XlaOp XlaBuilder::Conditional(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
      return InvalidArgument(
          "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    return ConditionalImpl(branch_index, branch_computations, branch_operands);
  });
}

XlaOp XlaBuilder::ConditionalImpl(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
                        GetShapePtr(branch_index));
    std::vector<Shape> branch_operand_shapes(branch_operands.size());
    std::vector<ProgramShape> branch_computation_shapes(
        branch_computations.size());
    for (int j = 0, end = branch_operands.size(); j < end; ++j) {
      TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
                          GetShape(branch_operands[j]));
      TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
                          branch_computations[j]->GetProgramShape());
    }
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferConditionalShape(
                            *branch_index_shape, branch_computation_shapes,
                            branch_operand_shapes));
    *instr.mutable_shape() = shape.ToProto();

    for (const XlaComputation* branch_computation : branch_computations) {
      AddCalledComputation(*branch_computation, &instr);
    }

    std::vector<XlaOp> operands(1, branch_index);
    for (const XlaOp branch_operand : branch_operands) {
      operands.emplace_back(branch_operand);
    }
    return AddInstruction(std::move(instr), HloOpcode::kConditional,
                          absl::MakeSpan(operands));
  });
}
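
// Note: the predicated form above lowers onto the indexed form with
// true_computation at branch index 0 and false_computation at index 1, so a
// PRED-selected Conditional is just a two-branch indexed Conditional.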
2441
CheckOpBuilder(XlaOp op) const2442 Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
2443 if (this != op.builder()) {
2444 return InvalidArgument(
2445 "XlaOp with handle %d is built by builder '%s', but is trying to use "
2446 "it in builder '%s'",
2447 op.handle(), op.builder()->name(), name());
2448 }
2449 return Status::OK();
2450 }
2451
Reduce(XlaOp operand,XlaOp init_value,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)2452 XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
2453 const XlaComputation& computation,
2454 absl::Span<const int64> dimensions_to_reduce) {
2455 return Reduce(absl::Span<const XlaOp>({operand}),
2456 absl::Span<const XlaOp>({init_value}), computation,
2457 dimensions_to_reduce);
2458 }
2459
Reduce(absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)2460 XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
2461 absl::Span<const XlaOp> init_values,
2462 const XlaComputation& computation,
2463 absl::Span<const int64> dimensions_to_reduce) {
2464 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2465 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
2466 computation.GetProgramShape());
2467
2468 std::vector<XlaOp> all_operands;
2469 all_operands.insert(all_operands.end(), operands.begin(), operands.end());
2470 all_operands.insert(all_operands.end(), init_values.begin(),
2471 init_values.end());
2472
2473 std::vector<const Shape*> operand_shape_ptrs;
2474 TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
2475 GetOperandShapes(all_operands));
2476 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2477 [](const Shape& shape) { return &shape; });
2478
2479 TF_ASSIGN_OR_RETURN(
2480 Shape shape,
2481 ShapeInference::InferReduceShape(
2482 operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
2483 return ReduceInternal(shape, all_operands, computation,
2484 dimensions_to_reduce);
2485 });
2486 }
2487
ReduceInternal(const Shape & shape,absl::Span<const XlaOp> all_operands,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)2488 StatusOr<XlaOp> XlaBuilder::ReduceInternal(
2489 const Shape& shape, absl::Span<const XlaOp> all_operands,
2490 const XlaComputation& computation,
2491 absl::Span<const int64> dimensions_to_reduce) {
2492 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2493 HloInstructionProto instr;
2494 *instr.mutable_shape() = shape.ToProto();
2495
2496 for (int64_t dim : dimensions_to_reduce) {
2497 instr.add_dimensions(dim);
2498 }
2499
2500 AddCalledComputation(computation, &instr);
2501 return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
2502 });
2503 }
2504
ReduceAll(XlaOp operand,XlaOp init_value,const XlaComputation & computation)2505 XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
2506 const XlaComputation& computation) {
2507 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2508 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2509 std::vector<int64> all_dimnos(operand_shape->rank());
2510 std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
2511 return Reduce(operand, init_value, computation, all_dimnos);
2512 });
2513 }
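
// Usage sketch (a hedged illustration; the `add` reducer is built here only
// for the example and the shapes are arbitrary): summing the rows of a 2D
// array with a scalar addition computation, mirroring how CrossReplicaSum
// below builds its reducer with CreateSubBuilder.
//
//   XlaBuilder b("reduce_example");
//   auto add_builder = b.CreateSubBuilder("add");
//   auto x = add_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
//   auto y = add_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
//   Add(x, y);
//   XlaComputation add = add_builder->Build().ValueOrDie();
//   XlaOp input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 16}), "in");
//   XlaOp row_sums = Reduce(input, ConstantR0<float>(&b, 0.0f), add,
//                           /*dimensions_to_reduce=*/{1});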

XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReduceWindow(absl::MakeSpan(&operand, 1),
                      absl::MakeSpan(&init_value, 1), computation,
                      window_dimensions, window_strides, padding);
}

XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
                               absl::Span<const XlaOp> init_values,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const Shape* operand_shape = nullptr;
    for (const auto& operand : operands) {
      TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
      TF_RETURN_IF_ERROR(
          ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
                                window_dimensions, window_strides));
    }
    CHECK(operand_shape != nullptr);
    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);
    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions.
        // We create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          ReduceWindowInternal(operands, init_values, computation,
                               window_dimensions, window_strides, {}, {},
                               padding_values));
      instr.set_custom_call_target("DynamicReduceWindowSamePadding");
      std::vector<XlaOp> args;
      args.insert(args.end(), operands.begin(), operands.end());
      args.insert(args.end(), init_values.begin(), init_values.end());
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall, args);
    }
    return ReduceWindowWithGeneralPadding(
        operands, init_values, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}
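
// Usage sketch (hedged; `max` is an assumed binary F32 maximum computation
// built with CreateSubBuilder, analogous to the `add` example above): a 2x2
// strided max-pool over the trailing two dimensions of a rank-4 input.
//
//   XlaOp pooled = ReduceWindow(
//       input, ConstantR0<float>(&b, -std::numeric_limits<float>::infinity()),
//       max,
//       /*window_dimensions=*/{1, 1, 2, 2},
//       /*window_strides=*/{1, 1, 2, 2}, Padding::kValid);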

XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (operands.size() == 1) {
      const auto& operand = operands[0];
      const auto& init_value = init_values[0];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      operand_shapes.push_back(operand_shape);
      TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
      init_shapes.push_back(init_shape);

      TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                          computation.GetProgramShape());
      TF_ASSIGN_OR_RETURN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              window_dimensions, window_strides, padding,
                              /*lhs_dilation=*/base_dilations,
                              /*rhs_dilation=*/window_dilations));
      TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
                                           absl::MakeSpan(operand_shapes),
                                           absl::MakeSpan(init_shapes), window,
                                           to_apply_shape));
      return ReduceWindowInternal(shape, operands[0], init_values[0],
                                  computation, window);
    }

    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        ReduceWindowInternal(operands, init_values, computation,
                             window_dimensions, window_strides, base_dilations,
                             window_dilations, padding));
    std::vector<XlaOp> args;
    args.insert(args.end(), operands.begin(), operands.end());
    args.insert(args.end(), init_values.begin(), init_values.end());
    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, args);
  });
}

StatusOr<HloInstructionProto> XlaBuilder::ReduceWindowInternal(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  for (int i = 0; i < operands.size(); ++i) {
    const auto& operand = operands[i];
    const auto& init_value = init_values[i];
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    operand_shapes.push_back(operand_shape);
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
    init_shapes.push_back(init_shape);
  }
  TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                      computation.GetProgramShape());
  TF_ASSIGN_OR_RETURN(auto window,
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/base_dilations,
                          /*rhs_dilation=*/window_dilations));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferReduceWindowShape(
                          absl::MakeSpan(operand_shapes),
                          absl::MakeSpan(init_shapes), window, to_apply_shape));
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);
  AddCalledComputation(computation, &instr);
  return instr;
}

StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);

  AddCalledComputation(computation, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
                        {operand, init_value});
}

XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                                    float epsilon, int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferBatchNormTrainingShape(
            *operand_shape, *scale_shape, *offset_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
                          {operand, scale, offset});
  });
}
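
// Usage sketch (hedged; the shapes and parameter names are assumptions for
// the example): batch-normalizing an NHWC activation over feature dimension
// 3. The result is a tuple of (normalized output, batch mean, batch
// variance).
//
//   XlaOp act = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4, 4, 16}),
//                         "act");
//   XlaOp scale = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16}), "scale");
//   XlaOp offset = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {16}),
//                            "offset");
//   XlaOp bn = BatchNormTraining(act, scale, offset, /*epsilon=*/1e-5,
//                                /*feature_index=*/3);
//   XlaOp normalized = GetTupleElement(bn, 0);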

XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                                     XlaOp mean, XlaOp variance, float epsilon,
                                     int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
    TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferBatchNormInferenceShape(
                            *operand_shape, *scale_shape, *offset_shape,
                            *mean_shape, *variance_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
                          {operand, scale, offset, mean, variance});
  });
}

XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                                XlaOp batch_var, XlaOp grad_output,
                                float epsilon, int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape,
                        GetShapePtr(batch_mean));
    TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
    TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
                        GetShapePtr(grad_output));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBatchNormGradShape(
                         *operand_shape, *scale_shape, *batch_mean_shape,
                         *batch_var_shape, *grad_output_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
                          {operand, scale, batch_mean, batch_var,
                           grad_output});
  });
}

XlaOp XlaBuilder::AllGather(XlaOp operand, int64_t all_gather_dimension,
                            int64_t shard_count,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Layout>& layout,
                            const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    TF_ASSIGN_OR_RETURN(
        Shape inferred_shape,
        ShapeInference::InferAllGatherShape({operand_shape},
                                            all_gather_dimension, shard_count));
    if (layout) {
      *inferred_shape.mutable_layout() = *layout;
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = inferred_shape.ToProto();

    instr.add_dimensions(all_gather_dimension);
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      instr.set_use_global_device_ids(use_global_device_ids.value());
    }

    TF_ASSIGN_OR_RETURN(
        auto all_gather,
        AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand}));
    return all_gather;
  });
}
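
// Usage sketch (hedged; the replica-group layout is an assumption for the
// example): gathering a [4, 8] shard along dimension 0 across 2 replicas,
// which leaves every participant holding the combined [8, 8] result.
//
//   ReplicaGroup group;
//   group.add_replica_ids(0);
//   group.add_replica_ids(1);
//   XlaOp gathered = AllGather(shard, /*all_gather_dimension=*/0,
//                              /*shard_count=*/2, {group});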

XlaOp XlaBuilder::CrossReplicaSum(
    XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    const Shape* element_shape;
    if (shape->IsTuple()) {
      if (shape->tuple_shapes_size() == 0) {
        return Unimplemented(
            "0 element tuple CrossReplicaSum is not supported");
      }
      element_shape = &shape->tuple_shapes(0);
    } else {
      element_shape = shape;
    }
    const Shape scalar_shape =
        ShapeUtil::MakeShape(element_shape->element_type(), {});
    auto b = CreateSubBuilder("sum");
    auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
    auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
    if (scalar_shape.element_type() == PRED) {
      Or(x, y);
    } else {
      Add(x, y);
    }
    TF_ASSIGN_OR_RETURN(auto computation, b->Build());
    return AllReduce(operand, computation, replica_groups,
                     /*channel_id=*/absl::nullopt);
  });
}

XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Shape>& shape_with_layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple AllReduce is not supported");
      }
      for (int64_t i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of AllReduce must have the same "
              "element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferAllReduceShape(operand_shapes));
    if (shape_with_layout) {
      if (!LayoutUtil::HasLayout(*shape_with_layout)) {
        return InvalidArgument("shape_with_layout must have the layout set: %s",
                               shape_with_layout->ToString());
      }
      if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
        return InvalidArgument(
            "Provided shape_with_layout must be compatible with the "
            "operand shape: %s vs %s",
            shape_with_layout->ToString(), operand_shape->ToString());
      }
      instr.set_constrain_layout(true);
      if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
        // For a single-element tuple, take the tuple element shape.
        TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
        *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
      } else {
        *instr.mutable_shape() = shape_with_layout->ToProto();
      }
    } else {
      *instr.mutable_shape() = inferred_shape.ToProto();
    }

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }

    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }

    AddCalledComputation(computation, &instr);

    TF_ASSIGN_OR_RETURN(
        auto all_reduce,
        AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
    if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
      // For a single-element tuple, wrap the result into a tuple.
      TF_RET_CHECK(operand_shapes.size() == 1);
      TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
      return Tuple({all_reduce});
    }
    return all_reduce;
  });
}
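
// Usage sketch (hedged; `add` is assumed to be a scalar F32 addition
// computation, as in the Reduce example above): summing a tensor across all
// replicas in a single default group.
//
//   XlaOp summed = AllReduce(local_values, add);
//
// CrossReplicaSum above is effectively this call with the addition (or, for
// PRED inputs, logical-or) reducer synthesized on the fly.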

XlaOp XlaBuilder::ReduceScatter(
    XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
    int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups,
    const absl::optional<ChannelHandle>& channel_id,
    const absl::optional<Layout>& layout,
    const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple ReduceScatter is not supported");
      }
      for (int64_t i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of ReduceScatter must have "
              "the same element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferReduceScatterShape(
                            operand_shapes, scatter_dimension, shard_count));
    if (layout) {
      *inferred_shape.mutable_layout() = *layout;
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = inferred_shape.ToProto();

    AddCalledComputation(computation, &instr);

    instr.add_dimensions(scatter_dimension);
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      instr.set_use_global_device_ids(use_global_device_ids.value());
    }

    TF_ASSIGN_OR_RETURN(auto reduce_scatter,
                        AddInstruction(std::move(instr),
                                       HloOpcode::kReduceScatter, {operand}));
    return reduce_scatter;
  });
}

XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
                           int64_t concat_dimension, int64_t split_count,
                           absl::Span<const ReplicaGroup> replica_groups,
                           const absl::optional<Layout>& layout) {
  // An array AllToAll may need to violate the layout constraint to be legal,
  // so when a layout is given, use the tuple version instead.
  if (layout.has_value()) {
    return AllToAllTuple(operand, split_dimension, concat_dimension,
                         split_count, replica_groups, layout);
  }
  return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
                       replica_groups);
}

XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
                                int64_t concat_dimension, int64_t split_count,
                                absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape all_to_all_shape,
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count));
    HloInstructionProto instr;
    *instr.mutable_shape() = operand_shape->ToProto();
    if (replica_groups.empty()) {
      auto* group = instr.add_replica_groups();
      for (int64_t i = 0; i < split_count; ++i) {
        group->add_replica_ids(i);
      }
    } else {
      for (const ReplicaGroup& group : replica_groups) {
        *instr.add_replica_groups() = group;
      }
    }
    instr.add_dimensions(split_dimension);
    TF_ASSIGN_OR_RETURN(
        XlaOp all_to_all,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
    if (split_dimension == concat_dimension) {
      return all_to_all;
    }
    // The exchange happened in place along the split dimension. Reshape that
    // dimension into (split_count, block) pieces, move the split_count factor
    // next to the concat dimension, and collapse it there.
    DimensionVector sizes;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (i != split_dimension) {
        sizes.push_back(operand_shape->dimensions(i));
        continue;
      }
      sizes.push_back(split_count);
      sizes.push_back(operand_shape->dimensions(i) / split_count);
    }
    all_to_all = Reshape(all_to_all, sizes);

    std::vector<int64> permutation;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      int64_t dim_after_reshape = i >= split_dimension ? i + 1 : i;
      if (i == concat_dimension) {
        permutation.push_back(split_dimension);
      }
      permutation.push_back(dim_after_reshape);
    }
    all_to_all = Transpose(all_to_all, permutation);
    return Reshape(all_to_all_shape, all_to_all);
  });
}
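
// Usage sketch (hedged; the shapes are assumptions): with split_count = 2, an
// array AllToAll on a [4, 6] operand that splits dimension 0 and concatenates
// on dimension 1 leaves each replica with a [2, 12] result, i.e. every
// replica keeps one split of its own data and receives one from its peer.
//
//   XlaOp exchanged = AllToAll(local, /*split_dimension=*/0,
//                              /*concat_dimension=*/1, /*split_count=*/2);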

XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64_t split_dimension,
                                int64_t concat_dimension, int64_t split_count,
                                absl::Span<const ReplicaGroup> replica_groups,
                                const absl::optional<Layout>& layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    // The HloInstruction for AllToAll currently only handles the data
    // communication: it accepts N already-split parts and scatters them to N
    // cores, and each core gathers the N received parts into a tuple as the
    // output. So here we explicitly split the operand before the HLO
    // AllToAll, and concatenate the tuple elements afterwards.
    //
    // First, run shape inference to make sure the shapes are valid.
    TF_RETURN_IF_ERROR(
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count)
            .status());

    // Split into N parts.
    std::vector<XlaOp> slices;
    slices.reserve(split_count);
    const int64_t block_size =
        operand_shape->dimensions(split_dimension) / split_count;
    for (int i = 0; i < split_count; i++) {
      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
                                  /*limit_index=*/(i + 1) * block_size,
                                  /*stride=*/1, /*dimno=*/split_dimension));
    }

    // Handle data communication.
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
    std::vector<const Shape*> slice_shape_ptrs;
    absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));

    if (layout) {
      TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
      for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
        const int64_t layout_minor_to_major_size =
            layout->minor_to_major().size();
        if (layout_minor_to_major_size != shape.tuple_shapes(i).rank()) {
          return InvalidArgument(
              "Provided layout must be compatible with the operand shape: %s "
              "vs %s",
              layout->ToString(), operand_shape->ToString());
        }
        *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
      }
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = shape.ToProto();

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    TF_ASSIGN_OR_RETURN(
        XlaOp alltoall,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));

    // Concat the N received parts.
    std::vector<XlaOp> received;
    received.reserve(split_count);
    for (int i = 0; i < split_count; i++) {
      received.push_back(this->GetTupleElement(alltoall, i));
    }
    return this->ConcatInDim(received, concat_dimension);
  });
}

XlaOp XlaBuilder::CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferCollectivePermuteShape({operand_shape}));
    *instr.mutable_shape() = shape.ToProto();

    for (const auto& pair : source_target_pairs) {
      auto* proto_pair = instr.add_source_target_pairs();
      proto_pair->set_source(pair.first);
      proto_pair->set_target(pair.second);
    }

    return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
                          {operand});
  });
}
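
// Usage sketch (hedged; a 4-replica ring rotation is an assumption for the
// example): each replica forwards its value to the next replica, with
// replica 3 wrapping around to replica 0.
//
//   XlaOp rotated =
//       CollectivePermute(value, /*source_target_pairs=*/{{0, 1}, {1, 2},
//                                                         {2, 3}, {3, 0}});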

XlaOp XlaBuilder::ReplicaId() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
  });
}

XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                   absl::Span<const int64> window_dimensions,
                                   absl::Span<const int64> window_strides,
                                   Padding padding, XlaOp source,
                                   XlaOp init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);

    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions.
        // We create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          SelectAndScatterInternal(operand, select, window_dimensions,
                                   window_strides, padding_values, source,
                                   init_value, scatter));
      instr.set_custom_call_target("DynamicSelectAndScatterSamePadding");
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                            {operand, source, init_value});
    }
    return SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides, padding_values,
        source, init_value, scatter);
  });
}

StatusOr<HloInstructionProto> XlaBuilder::SelectAndScatterInternal(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  HloInstructionProto instr;

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
  TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
  TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
  TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
                      select.GetProgramShape());
  TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
                      scatter.GetProgramShape());
  TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferSelectAndScatterShape(
                          *operand_shape, select_shape, instr.window(),
                          *source_shape, *init_shape, scatter_shape));
  *instr.mutable_shape() = shape.ToProto();

  AddCalledComputation(select, &instr);
  AddCalledComputation(scatter, &instr);
  return instr;
}

XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(HloInstructionProto instr,
                        SelectAndScatterInternal(
                            operand, select, window_dimensions, window_strides,
                            padding, source, init_value, scatter));

    return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
                          {operand, source, init_value});
  });
}

XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferReducePrecisionShape(
                            *operand_shape, exponent_bits, mantissa_bits));
    return ReducePrecisionInternal(shape, operand, exponent_bits,
                                   mantissa_bits);
  });
}

StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
                                                    XlaOp operand,
                                                    const int exponent_bits,
                                                    const int mantissa_bits) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_exponent_bits(exponent_bits);
  instr.set_mantissa_bits(mantissa_bits);
  return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
                        {operand});
}
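
// Usage sketch (hedged): rounding an F32 value to the precision of bfloat16
// (8 exponent bits, 7 mantissa bits) while keeping the F32 element type.
//
//   XlaOp rounded = ReducePrecision(value, /*exponent_bits=*/8,
//                                   /*mantissa_bits=*/7);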

void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Send HLO takes two operands: a data operand and a token. Generate the
    // token to pass into the send.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    return SendWithToken(operand, token, handle);
  });
}

XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Send must use a device-to-device channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Recv HLO takes a single token operand. Generate the token to pass into
    // the Recv and RecvDone instructions.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    XlaOp recv = RecvWithToken(token, shape, handle);

    // The RecvDone instruction produces a tuple of the data and a token
    // type. Return an XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto recv_data;
    *recv_data.mutable_shape() = shape.ToProto();
    recv_data.set_tuple_index(0);
    return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
                          {recv});
  });
}

XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Recv must use a device-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}
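
// Usage sketch (hedged; the channel plumbing and the two builders are
// assumptions for the example): pairing a token-threaded send in one
// computation with a matching receive in another, over a device-to-device
// channel.
//
//   // In the sending computation:
//   XlaOp send_token = CreateToken(&sender);
//   XlaOp send_done = SendWithToken(payload, send_token, channel);
//
//   // In the receiving computation:
//   XlaOp recv_token = CreateToken(&receiver);
//   XlaOp result = RecvWithToken(recv_token, payload_shape, channel);
//   XlaOp data = GetTupleElement(result, 0);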

XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
                             const Shape& shape_with_layout,
                             const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Shape passed to SendToHost must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "SendToHost shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    // TODO(b/111544877): Support tuple shapes.
    if (!operand_shape->IsArray()) {
      return InvalidArgument("SendToHost only supports array shapes, shape: %s",
                             ShapeUtil::HumanString(*operand_shape));
    }

    if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
      return InvalidArgument("SendToHost must use a device-to-host channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape_with_layout,
                                   ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    send_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    send_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
                               const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Shape passed to RecvFromHost must have a layout");
    }

    // TODO(b/111544877): Support tuple shapes.
    if (!shape.IsArray()) {
      return InvalidArgument(
          "RecvFromHost only supports array shapes, shape: %s",
          ShapeUtil::HumanString(shape));
    }

    if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
      return InvalidArgument("RecvFromHost must use a host-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    recv_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    recv_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}

XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
                                         *operand_shape, dimension));
    // Calling GetDimensionSize on a static dimension returns a constant
    // instruction.
    if (!operand_shape->is_dynamic_dimension(dimension)) {
      return ConstantR0<int32>(this, operand_shape->dimensions(dimension));
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
                          {operand});
  });
}
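
// Usage sketch (hedged; `dynamic_input` is an assumed operand with a dynamic
// leading dimension): querying the runtime size of a dynamic dimension. For
// a static dimension the same call folds to a constant instead of emitting a
// kGetDimensionSize instruction, per the branch above.
//
//   XlaOp rows = GetDimensionSize(dynamic_input, /*dimension=*/0);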

XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    Shape shape = *operand_shape;
    shape.set_dynamic_dimension(dimension, false);
    // Setting an op's dynamic dimension to its static size removes the
    // dynamic dimension.
    XlaOp static_size =
        ConstantR0<int32>(this, operand_shape->dimensions(dimension));
    return SetDimensionSizeInternal(shape, operand, static_size, dimension);
  });
}

XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val,
                                   int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));

    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferSetDimensionSizeShape(
                            *operand_shape, *val_shape, dimension));
    return SetDimensionSizeInternal(shape, operand, val, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::SetDimensionSizeInternal(const Shape& shape,
                                                     XlaOp operand, XlaOp val,
                                                     int64_t dimension) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
                      LookUpInstruction(val));
  if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
          HloOpcode::kConstant &&
      shape.is_dynamic_dimension(dimension)) {
    TF_ASSIGN_OR_RETURN(auto constant_size,
                        Literal::CreateFromProto(val_proto->literal(), true));
    if (constant_size.Get<int32>({}) == shape.dimensions(dimension)) {
      return operand;
    }
  }

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.add_dimensions(dimension);
  return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
                        {operand, val});
}

StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
  TF_RETURN_IF_ERROR(first_error_);

  // Verify that the handle is valid.
  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());

  bool is_constant = true;
  absl::flat_hash_set<int64> visited;
  IsConstantVisitor(operand.handle(), /*depth=*/0, &visited, &is_constant);
  return is_constant;
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);
    string op_string =
        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
    return InvalidArgument(
        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
        " op requested for constant subgraph: %s\n\n"
        "This is an internal error that typically happens when the XLA user "
        "(e.g. TensorFlow) is attempting to determine a value that must be a "
        "compile-time constant (e.g. an array dimension) but it is not capable "
        "of being evaluated at XLA compile time.\n\n"
        "Please file a usability bug with the framework being used (e.g. "
        "TensorFlow).",
        op_string);
  }

  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Build constant subgraph for:\n" << OpToString(root_op);
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                    GetNextId());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() = root->shape();

  // We use std::set to keep the instruction ids in ascending order (which is
  // also a valid dependency order). The related ops will be added to the
  // subgraph in the same order.
  std::set<int64> related_ops;
  absl::flat_hash_map<int64, int64> substitutions;
  absl::flat_hash_set<int64> related_calls;  // Related computations.
  std::queue<int64> worklist;
  worklist.push(root->id());
  related_ops.insert(root->id());

  while (!worklist.empty()) {
    int64_t handle = worklist.front();
    worklist.pop();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(handle));

    auto default_behavior = [&related_ops, &worklist, &related_calls,
                             instr_proto]() {
      for (int64_t id : instr_proto->operand_ids()) {
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
      }
      for (int64_t called_id : instr_proto->called_computation_ids()) {
        related_calls.insert(called_id);
      }
    };

    if (instr_proto->opcode() ==
            HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_proto)) {
      int32_t constant_value = -1;
      HloInstructionProto const_instr;

      if (instr_proto->opcode() ==
          HloOpcodeString(HloOpcode::kGetDimensionSize)) {
        // At this point, BuildConstantSubGraph should never encounter a
        // GetDimensionSize with a dynamic dimension; the IsConstant check at
        // the beginning of this function would have failed.
        //
        // Replace GetDimensionSize with a Constant representing the static
        // bound of the shape.
        int64_t dimension = instr_proto->dimensions(0);
        int64_t operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));

        if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
              dynamic_dimension_is_minus_one)) {
          constant_value =
              static_cast<int32>(operand_proto->shape().dimensions(dimension));
        }
        Literal literal = LiteralUtil::CreateR0(constant_value);
        *const_instr.mutable_literal() = literal.ToProto();
        *const_instr.mutable_shape() = literal.shape().ToProto();
      } else {
        if (instr_proto->literal().shape().element_type() == TUPLE) {
          // The first literal of SetBound contains the bounds; the second
          // literal contains the dynamism indicators.
          *const_instr.mutable_literal() =
              instr_proto->literal().tuple_literals(0);
        } else {
          *const_instr.mutable_literal() = instr_proto->literal();
        }

        *const_instr.mutable_shape() = instr_proto->shape();
      }
      *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
      const_instr.set_id(handle);
      *const_instr.mutable_name() =
          GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
      // Add to the result constant graph.
      *entry.add_instructions() = const_instr;

    } else if (instr_proto->opcode() ==
               HloOpcodeString(HloOpcode::kGetTupleElement)) {
      // Look through GTE(Tuple(..), i).
      TF_ASSIGN_OR_RETURN(
          const HloInstructionProto* maybe_tuple_instr,
          LookUpInstructionByHandle(instr_proto->operand_ids(0)));

      if (maybe_tuple_instr->opcode() == HloOpcodeString(HloOpcode::kTuple)) {
        int64_t id = maybe_tuple_instr->operand_ids(instr_proto->tuple_index());
        // Enqueue any dependencies of `id`.
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
        substitutions[handle] = id;

      } else {
        default_behavior();
      }

    } else {
      default_behavior();
    }
  }

  // Resolve any substitutions for the root id.
  int64_t root_id = root->id();
  auto it = substitutions.find(root_id);
  while (it != substitutions.end()) {
    root_id = it->second;
    it = substitutions.find(root_id);
  }
  entry.set_root_id(root_id);

  // Add related ops to the computation.
  for (int64_t id : related_ops) {
    if (substitutions.find(id) != substitutions.end()) {
      // Skip adding this instruction; we will replace references to it with
      // the substitution instruction's id.
      continue;
    }
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
                        LookUpInstructionByHandle(id));

    if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_src)) {
      continue;
    }
    HloInstructionProto* instr = entry.add_instructions();
    *instr = *instr_src;
    // Replace operands in case we have substitutions mapped.
    instr->clear_operand_ids();
    for (int64_t operand_id : instr_src->operand_ids()) {
      auto it = substitutions.find(operand_id);
      while (it != substitutions.end()) {
        operand_id = it->second;
        it = substitutions.find(operand_id);
      }
      instr->add_operand_ids(operand_id);
    }
    // Ensure that the instruction names are unique within the graph.
    const string& new_name =
        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
    instr->set_name(new_name);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& e : embedded_) {
    if (related_calls.find(e.second.id()) != related_calls.end()) {
      *module->add_computations() = e.second;
    }
  }
  *module->add_computations() = std::move(entry);
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Constant computation:\n" << module->DebugString();
  }
  return std::move(computation);
}

std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
    const string& computation_name) {
  auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
  sub_builder->parent_builder_ = this;
  sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
  return sub_builder;
}

/* static */ ConvolutionDimensionNumbers
XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_kernel_output_feature_dimension(
      kConvKernelOutputDimension);
  dimension_numbers.set_kernel_input_feature_dimension(
      kConvKernelInputDimension);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 2);
    dimension_numbers.add_kernel_spatial_dimensions(i + 2);
    dimension_numbers.add_output_spatial_dimensions(i + 2);
  }
  return dimension_numbers;
}
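
// Usage sketch (hedged; assumes the kConv* constants place batch at 0,
// feature at 1, and spatial dimensions starting at 2): for num_spatial_dims
// == 2 the defaults above correspond to an NCHW activation layout and an
// OIHW kernel layout.
//
//   ConvolutionDimensionNumbers dnums =
//       XlaBuilder::CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/2);
//   XlaOp conv = ConvWithGeneralDimensions(
//       input, kernel, /*window_strides=*/{1, 1}, Padding::kValid, dnums);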

/* static */ Status XlaBuilder::Validate(
    const ConvolutionDimensionNumbers& dnum) {
  if (dnum.input_spatial_dimensions_size() < 2) {
    return FailedPrecondition("input spatial dimension < 2: %d",
                              dnum.input_spatial_dimensions_size());
  }
  if (dnum.kernel_spatial_dimensions_size() < 2) {
    return FailedPrecondition("kernel spatial dimension < 2: %d",
                              dnum.kernel_spatial_dimensions_size());
  }
  if (dnum.output_spatial_dimensions_size() < 2) {
    return FailedPrecondition("output spatial dimension < 2: %d",
                              dnum.output_spatial_dimensions_size());
  }

  if (std::set<int64>(
          {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
           dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%d, %d, %d, "
        "%d)",
        dnum.input_batch_dimension(), dnum.input_feature_dimension(),
        dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
  }
  if (std::set<int64>({dnum.kernel_output_feature_dimension(),
                       dnum.kernel_input_feature_dimension(),
                       dnum.kernel_spatial_dimensions(0),
                       dnum.kernel_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%d, %d, %d, "
        "%d)",
        dnum.kernel_output_feature_dimension(),
        dnum.kernel_input_feature_dimension(),
        dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
  }
  if (std::set<int64>({dnum.output_batch_dimension(),
                       dnum.output_feature_dimension(),
                       dnum.output_spatial_dimensions(0),
                       dnum.output_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%d, %d, %d, "
        "%d)",
        dnum.output_batch_dimension(), dnum.output_feature_dimension(),
        dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
  }
  return Status::OK();
}
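
// Added example (illustrative): Validate rejects dimension numbers that reuse
// a dimension, e.g.
//
//   auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2);
//   dnums.set_input_feature_dimension(dnums.input_batch_dimension());
//   // XlaBuilder::Validate(dnums) now returns FailedPrecondition
//   // ("dimension numbers for the input are not unique: ...").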

StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
                                           HloOpcode opcode,
                                           absl::Span<const XlaOp> operands) {
  TF_RETURN_IF_ERROR(first_error_);

  const int64_t handle = GetNextId();
  instr.set_id(handle);
  instr.set_opcode(HloOpcodeString(opcode));
  if (instr.name().empty()) {
    instr.set_name(instr.opcode());
  }
  for (const auto& operand : operands) {
    if (operand.builder_ == nullptr) {
      return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
    }
    if (operand.builder_ != this) {
      return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
                             operand.builder_->name(), this->name());
    }
    instr.add_operand_ids(operand.handle());
  }

  if (one_shot_metadata_.has_value()) {
    *instr.mutable_metadata() = one_shot_metadata_.value();
    one_shot_metadata_.reset();
  } else {
    *instr.mutable_metadata() = metadata_;
  }
  if (sharding_) {
    *instr.mutable_sharding() = *sharding_;
  }
  *instr.mutable_frontend_attributes() = frontend_attributes_;

  handle_to_index_[handle] = instructions_.size();
  instructions_.push_back(std::move(instr));
  instruction_shapes_.push_back(
      absl::make_unique<Shape>(instructions_.back().shape()));

  XlaOp op(handle, this);
  return op;
}
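
// Added note (sketch only; assumes the SetOneShotOpMetadata client API):
// one-shot metadata takes precedence over the builder-wide metadata_ and is
// consumed by exactly one instruction, so a client might write:
//
//   builder.SetOneShotOpMetadata(metadata);  // applies to the next op only
//   XlaOp y = Exp(x);                        // y carries `metadata`
//   XlaOp z = Exp(y);                        // z falls back to metadata_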

StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                           absl::Span<const XlaOp> operands) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), opcode, operands);
}

void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
                                      HloInstructionProto* instr) {
  absl::flat_hash_map<int64, int64> remapped_ids;
  std::vector<HloComputationProto> imported_computations;
  imported_computations.reserve(computation.proto().computations_size());
  // First, import the computations, remapping their IDs and capturing the
  // old->new mappings in remapped_ids.
  for (const HloComputationProto& e : computation.proto().computations()) {
    HloComputationProto new_computation(e);
    int64_t computation_id = GetNextId();
    remapped_ids[new_computation.id()] = computation_id;
    SetProtoIdAndName(&new_computation,
                      GetBaseName(new_computation.name(), kNameSeparator),
                      kNameSeparator, computation_id);
    for (auto& instruction : *new_computation.mutable_instructions()) {
      int64_t instruction_id = GetNextId();
      remapped_ids[instruction.id()] = instruction_id;
      SetProtoIdAndName(&instruction,
                        GetBaseName(instruction.name(), kNameSeparator),
                        kNameSeparator, instruction_id);
    }
    new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));

    imported_computations.push_back(std::move(new_computation));
  }
  // Once all the computations have been imported and all the ID mappings
  // captured, go back and fix up the IDs referenced inside the imported
  // computations.
  instr->add_called_computation_ids(
      remapped_ids.at(computation.proto().entry_computation_id()));
  for (auto& imported_computation : imported_computations) {
    for (auto& instruction : *imported_computation.mutable_instructions()) {
      for (auto& operand_id : *instruction.mutable_operand_ids()) {
        operand_id = remapped_ids.at(operand_id);
      }
      for (auto& control_predecessor_id :
           *instruction.mutable_control_predecessor_ids()) {
        control_predecessor_id = remapped_ids.at(control_predecessor_id);
      }
      for (auto& called_computation_id :
           *instruction.mutable_called_computation_ids()) {
        called_computation_id = remapped_ids.at(called_computation_id);
      }
    }

    int64_t computation_id = imported_computation.id();
    for (int64_t i = 0; i < imported_computation.instructions_size(); ++i) {
      ImportedInstruction imported_instruction;
      imported_instruction.computation_id = computation_id;
      imported_instruction.instruction_index = i;
      handle_to_imported_index_.insert(
          {imported_computation.instructions(i).id(), imported_instruction});
    }
    embedded_.insert({computation_id, std::move(imported_computation)});
  }
}
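
// Added worked example (hypothetical IDs): if the imported proto holds a
// computation with id 7 whose root instruction has id 9, and this builder's
// next IDs are 100 and 101, the import pass above produces
//
//   remapped_ids == {7 -> 100, 9 -> 101}
//
// and the fixup pass rewrites every operand_id, control_predecessor_id, and
// called_computation_id that mentioned 7 or 9 to 100 or 101, so the embedded
// computations reference only IDs owned by this builder.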

StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
    const XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  return LookUpInstructionInternal<const HloInstructionProto*>(op);
}

StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
    int64_t handle) const {
  return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
}

StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
    const XlaOp op) {
  TF_RETURN_IF_ERROR(first_error_);
  return LookUpInstructionInternal<HloInstructionProto*>(op);
}

StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
    int64_t handle) {
  return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
}

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const string& name) {
  std::vector<bool> empty_bools;
  return Parameter(builder, parameter_number, shape, name, empty_bools);
}

XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const string& name,
                const std::vector<bool>& replicated_at_leaf_buffers) {
  return builder->Parameter(parameter_number, shape, name,
                            replicated_at_leaf_buffers);
}
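
// Added usage sketch (illustrative): parameters are numbered consecutively
// from 0 and provide a computation's inputs, e.g.
//
//   XlaBuilder b("add_one");
//   Shape s = ShapeUtil::MakeShape(F32, {2, 2});
//   XlaOp x = Parameter(&b, /*parameter_number=*/0, s, "x");
//   Add(x, ConstantR0<float>(&b, 1.0f));  // scalars broadcast implicitly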

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
  return builder->ConstantLiteral(literal);
}

XlaOp Broadcast(const XlaOp operand, absl::Span<const int64> broadcast_sizes) {
  return operand.builder()->Broadcast(operand, broadcast_sizes);
}

XlaOp BroadcastInDim(const XlaOp operand,
                     const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions) {
  return operand.builder()->BroadcastInDim(operand, out_dim_size,
                                           broadcast_dimensions);
}
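
// Added example (illustrative): Broadcast prepends dimensions, while
// BroadcastInDim maps each operand dimension into the output shape. For `v`
// of shape f32[3]:
//
//   Broadcast(v, {2});               // f32[2,3]
//   BroadcastInDim(v, {2, 3}, {1});  // also f32[2,3]: dim 0 of v -> dim 1
//   BroadcastInDim(v, {3, 2}, {0});  // f32[3,2]: dim 0 of v -> dim 0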

XlaOp Copy(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
}

XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
          const PaddingConfig& padding_config) {
  return operand.builder()->Pad(operand, padding_value, padding_config);
}

XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
               int64_t pad_lo, int64_t pad_hi) {
  return operand.builder()->PadInDim(operand, padding_value, dimno, pad_lo,
                                     pad_hi);
}

XlaOp Reshape(const XlaOp operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes) {
  return operand.builder()->Reshape(operand, dimensions, new_sizes);
}

XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
  return operand.builder()->Reshape(operand, new_sizes);
}

XlaOp Reshape(const Shape& shape, XlaOp operand) {
  return operand.builder()->Reshape(shape, operand);
}
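
// Added example (illustrative): for `m` of shape f32[2,3] and a scalar f32 op
// `zero`,
//
//   Reshape(m, /*new_sizes=*/{3, 2});                            // f32[3,2]
//   PadInDim(m, zero, /*dimno=*/1, /*pad_lo=*/1, /*pad_hi=*/0);  // f32[2,4]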

XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic) {
  return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
                                           dims_are_dynamic);
}

XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64> new_sizes,
                                   int64_t inferred_dimension) {
  return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
}

XlaOp Collapse(const XlaOp operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Collapse(operand, dimensions);
}

XlaOp Slice(const XlaOp operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides) {
  return operand.builder()->Slice(operand, start_indices, limit_indices,
                                  strides);
}
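
// Added example (illustrative): Slice takes per-dimension [start, limit)
// bounds plus strides, so for `m` of shape f32[4,8]:
//
//   Slice(m, /*start_indices=*/{0, 2}, /*limit_indices=*/{4, 8},
//         /*strides=*/{1, 2});  // f32[4,3]: columns 2, 4, 6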

XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index,
                 int64_t stride, int64_t dimno) {
  return operand.builder()->SliceInDim(operand, start_index, limit_index,
                                       stride, dimno);
}

XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes) {
  return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}

XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
                         absl::Span<const XlaOp> start_indices) {
  return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}

XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64_t dimension) {
  return builder->ConcatInDim(operands, dimension);
}

void Trace(const string& tag, const XlaOp operand) {
  return operand.builder()->Trace(tag, operand);
}

XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
  return pred.builder()->Select(pred, on_true, on_false);
}

XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
  return builder->Tuple(elements);
}

XlaOp GetTupleElement(const XlaOp tuple_data, int64_t index) {
  return tuple_data.builder()->GetTupleElement(tuple_data, index);
}

XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}

static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
                               absl::Span<const int64> broadcast_dimensions,
                               ComparisonDirection comparison_direction) {
  auto b = lhs.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
    auto operand_element_type = operand_shape.element_type();
    auto compare_type =
        primitive_util::IsFloatingPointType(operand_element_type)
            ? Comparison::Type::kFloatTotalOrder
            : Comparison::DefaultComparisonType(operand_element_type);
    return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
                   compare_type);
  });
}
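
// Added note: the *TotalOrder variants below differ from the plain
// comparisons only for floating-point inputs, where kFloatTotalOrder imposes
// the total order -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN,
// e.g.
//
//   Eq(nan, nan);            // false under the default partial order
//   EqTotalOrder(nan, nan);  // true under the total order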

XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kEq);
}

XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
}

XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kNe);
}

XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
}

XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kGe);
}

XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
}

XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kGt);
}

XlaOp Le(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
}

XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kLe);
}

XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
}

XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kLt);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction) {
  return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                 broadcast_dimensions, direction);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type) {
  return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                 broadcast_dimensions, direction, compare_type);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
  return Compare(lhs, rhs, {}, direction);
}

XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
          const PrecisionConfig* precision_config,
          absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
}

XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config,
                 absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
                                   precision_config, preferred_element_type);
}

XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
           absl::Span<const int64> window_strides, Padding padding,
           int64_t feature_group_count, int64_t batch_group_count,
           const PrecisionConfig* precision_config,
           absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
                             feature_group_count, batch_group_count,
                             precision_config, preferred_element_type);
}

XlaOp ConvWithGeneralPadding(
    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvWithGeneralPadding(
      lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
      precision_config, preferred_element_type);
}

XlaOp ConvWithGeneralDimensions(
    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
      batch_group_count, precision_config, preferred_element_type);
}

XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
                  absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64_t feature_group_count, int64_t batch_group_count,
                  const PrecisionConfig* precision_config,
                  absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvGeneral(
      lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
      batch_group_count, precision_config, preferred_element_type);
}

XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64_t feature_group_count, int64_t batch_group_count,
                         const PrecisionConfig* precision_config,
                         absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvGeneralDilated(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, preferred_element_type);
}

XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvInputGrad(
      input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return activations.builder()->DynamicConvKernelGrad(
      activations, gradients, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64_t feature_group_count, int64_t batch_group_count,
                         const PrecisionConfig* precision_config,
                         PaddingType padding_type,
                         absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvForward(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp Fft(const XlaOp operand, FftType fft_type,
          absl::Span<const int64> fft_length) {
  return operand.builder()->Fft(operand, fft_type, fft_length);
}

XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
    xla::TriangularSolveOptions options;
    options.set_left_side(left_side);
    options.set_lower(lower);
    options.set_unit_diagonal(unit_diagonal);
    options.set_transpose_a(transpose_a);
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
                                         *a_shape, *b_shape, options));
    return builder->TriangularSolveInternal(shape, a, b, std::move(options));
  });
}
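
// Added example (illustrative): solving A * X = B for X, with A a lower
// triangular f32[n,n] and B a f32[n,k]:
//
//   XlaOp x = TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);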

XlaOp Cholesky(XlaOp a, bool lower) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferCholeskyShape(*a_shape));
    return builder->CholeskyInternal(shape, a, lower);
  });
}

XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
  return builder->Infeed(shape, config);
}

void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
             const string& outfeed_config) {
  return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
}

XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands) {
  return builder->Call(computation, operands);
}

XlaOp CustomCall(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             /*operand_shapes_with_layout=*/absl::nullopt,
                             has_side_effect, output_operand_aliasing, literal,
                             /*window=*/absl::nullopt, /*dnums=*/absl::nullopt,
                             schedule, api_version);
}

XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(
      call_target_name, operands, computation, shape, opaque,
      /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect,
      output_operand_aliasing, literal, schedule, api_version);
}

XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(
      call_target_name, operands, shape, opaque, operand_shapes_with_layout,
      has_side_effect, output_operand_aliasing, literal,
      /*window=*/absl::nullopt, /*dnums=*/absl::nullopt, schedule, api_version);
}

XlaOp CustomCallWithConvDnums(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  absl::optional<absl::Span<const Shape>> maybe_operand_shapes;
  if (!operand_shapes_with_layout.empty()) {
    maybe_operand_shapes = operand_shapes_with_layout;
  }
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             maybe_operand_shapes, has_side_effect,
                             output_operand_aliasing, literal, window, dnums,
                             schedule, api_version);
}

XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Conj(const XlaOp operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}
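
// Added note: Conj computes the complex conjugate conj(a + bi) = a - bi by
// recombining the real part with the negated imaginary part.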

XlaOp Add(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Div(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Max(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Min(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp And(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Or(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Not(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
}

XlaOp PopulationCount(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
}

XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
                absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
                        absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return operand.builder()->Reduce(operand, init_value, computation,
                                   dimensions_to_reduce);
}

// Reduces several arrays simultaneously along the provided dimensions, using
// "computation" as the reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return builder->Reduce(operands, init_values, computation,
                         dimensions_to_reduce);
}
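
// Added usage sketch (illustrative): a row-wise sum over `m` of shape
// f32[2,3], assuming `add` is an XlaComputation performing f32 addition:
//
//   XlaOp sums = Reduce(m, ConstantR0<float>(&b, 0.0f), add,
//                       /*dimensions_to_reduce=*/{1});  // f32[2]
//
// The variadic overload reduces several arrays in lockstep and returns a
// tuple with one result per operand.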

XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
                const XlaComputation& computation) {
  return operand.builder()->ReduceAll(operand, init_value, computation);
}

XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  return operand.builder()->ReduceWindow(operand, init_value, computation,
                                         window_dimensions, window_strides,
                                         padding);
}

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindow(operands, init_values, computation,
                                             window_dimensions, window_strides,
                                             padding);
}
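
// Added usage sketch (illustrative): 2x2 max pooling with stride 2 over a
// f32[N,C,H,W] input, assuming `max` is an XlaComputation computing f32 max
// and `lowest` is a scalar f32 op holding the lowest finite value:
//
//   XlaOp pooled = ReduceWindow(input, lowest, max,
//                               /*window_dimensions=*/{1, 1, 2, 2},
//                               /*window_strides=*/{1, 1, 2, 2},
//                               Padding::kValid);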

XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp operand, const XlaOp init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return operand.builder()->ReduceWindowWithGeneralPadding(
      absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
      window_dimensions, window_strides, base_dilations, window_dilations,
      padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindowWithGeneralPadding(
      operands, init_values, computation, window_dimensions, window_strides,
      base_dilations, window_dilations, padding);
}

XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
                int64_t shard_count,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Layout>& layout,
                const absl::optional<bool> use_global_device_ids) {
  return operand.builder()->AllGather(operand, all_gather_dimension,
                                      shard_count, replica_groups, channel_id,
                                      layout, use_global_device_ids);
}

XlaOp CrossReplicaSum(const XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups) {
  return operand.builder()->CrossReplicaSum(operand, replica_groups);
}

XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Shape>& shape_with_layout) {
  return operand.builder()->AllReduce(operand, computation, replica_groups,
                                      channel_id, shape_with_layout);
}

XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
                    int64_t scatter_dimension, int64_t shard_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const absl::optional<ChannelHandle>& channel_id,
                    const absl::optional<Layout>& layout,
                    const absl::optional<bool> use_global_device_ids) {
  return operand.builder()->ReduceScatter(
      operand, computation, scatter_dimension, shard_count, replica_groups,
      channel_id, layout, use_global_device_ids);
}

XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
               int64_t concat_dimension, int64_t split_count,
               absl::Span<const ReplicaGroup> replica_groups,
               const absl::optional<Layout>& layout) {
  return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
                                     split_count, replica_groups, layout);
}

XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension,
                    int64_t concat_dimension, int64_t split_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const absl::optional<Layout>& layout) {
  return operand.builder()->AllToAllTuple(operand, split_dimension,
                                          concat_dimension, split_count,
                                          replica_groups, layout);
}

XlaOp CollectivePermute(
    const XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return operand.builder()->CollectivePermute(operand, source_target_pairs);
}

XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }

XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp source, const XlaOp init_value,
                       const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
                                             window_strides, padding, source,
                                             init_value, scatter);
}

XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp source,
    const XlaOp init_value, const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      init_value, scatter);
}

XlaOp Abs(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Exp(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
}
XlaOp Expm1(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
}
XlaOp Floor(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
}
XlaOp Ceil(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
}
XlaOp Round(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
XlaOp Log(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
}
XlaOp Log1p(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
}
XlaOp Logistic(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand);
}
XlaOp Sign(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
}
XlaOp Clz(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
}
XlaOp Cos(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
}
XlaOp Sin(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
}
XlaOp Tanh(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
}
XlaOp Real(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
}
XlaOp Imag(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
}
XlaOp Sqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
}
XlaOp Cbrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand);
}
XlaOp Rsqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
}

XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp IsFinite(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}

XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->ConvertElementType(operand, new_element_type);
}

XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->BitcastConvertType(operand, new_element_type);
}

XlaOp Neg(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
}

XlaOp Transpose(const XlaOp operand, absl::Span<const int64> permutation) {
  return operand.builder()->Transpose(operand, permutation);
}

XlaOp Rev(const XlaOp operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Rev(operand, dimensions);
}

XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension, bool is_stable) {
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}
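
// Added usage sketch (illustrative): sorting a f32[n] op `v` ascending, with
// a comparator built in a sub-builder; the comparator takes two scalar
// parameters per sorted operand:
//
//   auto cmp = b.CreateSubBuilder("less_than");
//   auto p0 = Parameter(cmp.get(), 0, ShapeUtil::MakeShape(F32, {}), "a");
//   auto p1 = Parameter(cmp.get(), 1, ShapeUtil::MakeShape(F32, {}), "b");
//   Lt(p0, p1);
//   XlaOp sorted = Sort({v}, cmp->BuildAndNoteError(), /*dimension=*/0);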

XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
  return min.builder()->Clamp(min, operand, max);
}

XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions, static_operands);
}

XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  return initial_state.builder()->RngBitGenerator(algorithm, initial_state,
                                                  shape);
}

XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  return init.builder()->While(condition, body, init);
}

XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}
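
// Added note: the predicate form above is the boolean special case of the
// N-way form; with an S32 branch_index, branch k runs branch_computations[k]
// on branch_operands[k], and an out-of-range index runs the last branch.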
4730
ReducePrecision(const XlaOp operand,const int exponent_bits,const int mantissa_bits)4731 XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
4732 const int mantissa_bits) {
4733 return operand.builder()->ReducePrecision(operand, exponent_bits,
4734 mantissa_bits);
4735 }
4736
Gather(const XlaOp input,const XlaOp start_indices,const GatherDimensionNumbers & dimension_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)4737 XlaOp Gather(const XlaOp input, const XlaOp start_indices,
4738 const GatherDimensionNumbers& dimension_numbers,
4739 absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
4740 return input.builder()->Gather(input, start_indices, dimension_numbers,
4741 slice_sizes, indices_are_sorted);
4742 }
4743
Scatter(const XlaOp input,const XlaOp scatter_indices,const XlaOp updates,const XlaComputation & update_computation,const ScatterDimensionNumbers & dimension_numbers,bool indices_are_sorted,bool unique_indices)4744 XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
4745 const XlaOp updates, const XlaComputation& update_computation,
4746 const ScatterDimensionNumbers& dimension_numbers,
4747 bool indices_are_sorted, bool unique_indices) {
4748 return input.builder()->Scatter(input, scatter_indices, updates,
4749 update_computation, dimension_numbers,
4750 indices_are_sorted, unique_indices);
4751 }
4752
void Send(const XlaOp operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

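// Example (illustrative sketch): threading an explicit token through a
// send/recv pair so the two effects stay ordered. `data` and `channel` are
// assumed client values.
//
//   XlaOp token = CreateToken(&b);
//   XlaOp done = SendWithToken(data, token, channel);
//   XlaOp recv = RecvWithToken(done, ShapeUtil::MakeShape(F32, {4}), channel);
//   XlaOp payload = GetTupleElement(recv, 0);  // Element 1 is the new token.
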
XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

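// Example (illustrative sketch): a token-ordered infeed/outfeed round trip.
// InfeedWithToken returns a (data, token) tuple; OutfeedWithToken returns a
// token. The empty config strings are assumptions.
//
//   XlaOp t = CreateToken(&b);
//   XlaOp infeed = InfeedWithToken(t, ShapeUtil::MakeShape(F32, {4}), "");
//   XlaOp out_t = OutfeedWithToken(GetTupleElement(infeed, 0),
//                                  GetTupleElement(infeed, 1),
//                                  ShapeUtil::MakeShape(F32, {4}), "");
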
XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}

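// Example (illustrative sketch): joining the tokens of two independent
// side-effecting ops so a later op is ordered after both.
//
//   XlaOp joined = AfterAll(&b, {token_a, token_b});
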
XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64_t feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64_t feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64_t feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean,
                                          batch_var, grad_output, epsilon,
                                          feature_index);
}

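// Example (illustrative sketch): BatchNormTraining returns a
// (normalized, batch_mean, batch_var) tuple; for an NHWC activation,
// feature_index 3 selects the channel dimension. Operand names are assumed.
//
//   XlaOp bn = BatchNormTraining(act, scale, offset, /*epsilon=*/1e-5f,
//                                /*feature_index=*/3);
//   XlaOp normalized = GetTupleElement(bn, 0);
//   XlaOp batch_mean = GetTupleElement(bn, 1);
//   XlaOp batch_var = GetTupleElement(bn, 2);
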
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}

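// Example (illustrative sketch): with iota_dimension == 0, every element of
// the result equals its index along dimension 0.
//
//   XlaOp v = Iota(&b, S32, 4);                                // {0, 1, 2, 3}
//   XlaOp m = Iota(&b, ShapeUtil::MakeShape(S32, {3, 4}), 0);  // row i is i
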
XlaOp GetDimensionSize(const XlaOp operand, int64_t dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val,
                       int64_t dimension) {
  return operand.builder()->SetDimensionSize(operand, val, dimension);
}

XlaOp RemoveDynamicDimension(const XlaOp operand, int64_t dimension) {
  return operand.builder()->RemoveDynamicDimension(operand, dimension);
}

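// Example (illustrative sketch): round-tripping a dynamic dimension with the
// three helpers above. `v` is f32[10] and `n` is a scalar s32 runtime size.
//
//   XlaOp sized = SetDimensionSize(v, n, /*dimension=*/0);  // f32[<=10]
//   XlaOp size = GetDimensionSize(sized, 0);                // equals n
//   XlaOp fixed = RemoveDynamicDimension(sized, 0);         // f32[10] again
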
}  // namespace xla