1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/xla_builder.h"
17
18 #include <functional>
19 #include <iterator>
20 #include <memory>
21 #include <numeric>
22 #include <optional>
23 #include <queue>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/numbers.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_split.h"
36 #include "absl/types/span.h"
37 #include "tensorflow/compiler/xla/client/sharding_builder.h"
38 #include "tensorflow/compiler/xla/client/xla_computation.h"
39 #include "tensorflow/compiler/xla/comparison_util.h"
40 #include "tensorflow/compiler/xla/layout_util.h"
41 #include "tensorflow/compiler/xla/permutation_util.h"
42 #include "tensorflow/compiler/xla/primitive_util.h"
43 #include "tensorflow/compiler/xla/service/hlo.pb.h"
44 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
45 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
46 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
47 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
48 #include "tensorflow/compiler/xla/service/shape_inference.h"
49 #include "tensorflow/compiler/xla/sharding_op_util.h"
50 #include "tensorflow/compiler/xla/status_macros.h"
51 #include "tensorflow/compiler/xla/util.h"
52 #include "tensorflow/compiler/xla/window_util.h"
53 #include "tensorflow/compiler/xla/xla_data.pb.h"
54 #include "tensorflow/core/platform/errors.h"
55
56 namespace xla {
57
58 using absl::StrCat;
59
60 namespace {
61
62 static const char kNameSeparator = '.';
63
64 // Retrieves the base name of an instruction or computation fully qualified
65 // name, using separator as boundary between the initial base name part, and
66 // the numeric identification.
GetBaseName(const std::string & name,char separator)67 std::string GetBaseName(const std::string& name, char separator) {
68 auto pos = name.rfind(separator);
69 CHECK_NE(pos, std::string::npos) << name;
70 return name.substr(0, pos);
71 }
72
73 // Generates a fully qualified computation/instruction name.
GetFullName(const std::string & base_name,char separator,int64_t id)74 std::string GetFullName(const std::string& base_name, char separator,
75 int64_t id) {
76 const char separator_str[] = {separator, '\0'};
77 return StrCat(base_name, separator_str, id);
78 }
79
80 // Common function to standardize setting name and IDs on computation and
81 // instruction proto entities.
82 template <typename T>
SetProtoIdAndName(T * entry,const std::string & base_name,char separator,int64_t id)83 void SetProtoIdAndName(T* entry, const std::string& base_name, char separator,
84 int64_t id) {
85 entry->set_id(id);
86 entry->set_name(GetFullName(base_name, separator, id));
87 }
88
InstrIsSetBound(const HloInstructionProto * instr_proto)89 bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
90 HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
91 if (opcode == HloOpcode::kCustomCall &&
92 instr_proto->custom_call_target() == "SetBound") {
93 return true;
94 }
95 return false;
96 }
97
98 } // namespace
99
100 namespace internal {
101
BuildAddDependency(XlaBuilder * builder,XlaOp operand,XlaOp token,const Shape & shape)102 XlaOp XlaBuilderFriend::BuildAddDependency(XlaBuilder* builder, XlaOp operand,
103 XlaOp token, const Shape& shape) {
104 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
105 HloInstructionProto instr;
106 *instr.mutable_shape() = shape.ToProto();
107 return builder->AddInstruction(std::move(instr), HloOpcode::kAddDependency,
108 {operand, token});
109 });
110 }
111
BuildFusion(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::string_view fusion_kind,const XlaComputation & fused_computation)112 XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
113 absl::Span<const XlaOp> operands,
114 absl::string_view fusion_kind,
115 const XlaComputation& fused_computation) {
116 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
117 HloInstructionProto instr;
118 instr.set_fusion_kind(std::string(fusion_kind));
119 std::vector<const Shape*> operand_shape_ptrs;
120 TF_ASSIGN_OR_RETURN(auto program_shape,
121 fused_computation.GetProgramShape());
122 *instr.mutable_shape() = program_shape.result().ToProto();
123 builder->AddCalledComputation(fused_computation, &instr);
124 return builder->AddInstruction(std::move(instr), HloOpcode::kFusion,
125 operands);
126 });
127 }
128
BuildBitcast(XlaBuilder * builder,XlaOp operand,const Shape & shape)129 XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
130 const Shape& shape) {
131 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
132 HloInstructionProto instr;
133 *instr.mutable_shape() = shape.ToProto();
134 return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
135 {operand});
136 });
137 }
138
BuildDomain(XlaBuilder * builder,XlaOp operand,const OpSharding entry,const OpSharding exit,const Shape & shape)139 XlaOp XlaBuilderFriend::BuildDomain(XlaBuilder* builder, XlaOp operand,
140 const OpSharding entry,
141 const OpSharding exit, const Shape& shape) {
142 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
143 HloInstructionProto instr;
144 *instr.mutable_domain_entry_sharding() = entry;
145 *instr.mutable_domain_exit_sharding() = exit;
146 *instr.mutable_shape() = shape.ToProto();
147 return builder->AddInstruction(std::move(instr), HloOpcode::kDomain,
148 {operand});
149 });
150 }
151
BuildPartitionId(XlaBuilder * builder,const Shape & shape)152 XlaOp XlaBuilderFriend::BuildPartitionId(XlaBuilder* builder,
153 const Shape& shape) {
154 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
155 HloInstructionProto instr;
156 *instr.mutable_shape() = shape.ToProto();
157 return builder->AddInstruction(std::move(instr), HloOpcode::kPartitionId);
158 });
159 }
160
BuildRngGetAndUpdateState(XlaBuilder * builder,int64_t delta,const Shape & shape)161 XlaOp XlaBuilderFriend::BuildRngGetAndUpdateState(XlaBuilder* builder,
162
163 int64_t delta,
164 const Shape& shape) {
165 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
166 HloInstructionProto instr;
167 instr.set_delta(delta);
168 *instr.mutable_shape() = shape.ToProto();
169 return builder->AddInstruction(std::move(instr),
170 HloOpcode::kRngGetAndUpdateState);
171 });
172 }
173
GetInstruction(XlaOp op)174 HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
175 return &op.builder()
176 ->instructions_[op.builder()->handle_to_index_[op.handle_]];
177 }
178
GetInstructionByHandle(XlaBuilder * builder,int64_t handle)179 HloInstructionProto* XlaBuilderFriend::GetInstructionByHandle(
180 XlaBuilder* builder, int64_t handle) {
181 return &builder->instructions_[builder->handle_to_index_[handle]];
182 }
183
184 } // namespace internal
185
operator -(XlaOp x)186 XlaOp operator-(XlaOp x) { return Neg(x); }
operator +(XlaOp x,XlaOp y)187 XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
operator -(XlaOp x,XlaOp y)188 XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
operator *(XlaOp x,XlaOp y)189 XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
operator /(XlaOp x,XlaOp y)190 XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
operator %(XlaOp x,XlaOp y)191 XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }
192
operator ~(XlaOp x)193 XlaOp operator~(XlaOp x) { return Not(x); }
operator &(XlaOp x,XlaOp y)194 XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
operator |(XlaOp x,XlaOp y)195 XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
operator ^(XlaOp x,XlaOp y)196 XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
operator <<(XlaOp x,XlaOp y)197 XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }
198
operator >>(XlaOp x,XlaOp y)199 XlaOp operator>>(XlaOp x, XlaOp y) {
200 XlaBuilder* builder = x.builder();
201 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
202 TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(x));
203 if (!ShapeUtil::ElementIsIntegral(*shape)) {
204 return InvalidArgument(
205 "Argument to >> operator does not have an integral type (%s).",
206 ShapeUtil::HumanString(*shape));
207 }
208 if (ShapeUtil::ElementIsSigned(*shape)) {
209 return ShiftRightArithmetic(x, y);
210 } else {
211 return ShiftRightLogical(x, y);
212 }
213 });
214 }
215
GetShapePtr(XlaOp op) const216 StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
217 TF_RETURN_IF_ERROR(first_error_);
218 TF_RETURN_IF_ERROR(CheckOpBuilder(op));
219 auto it = handle_to_index_.find(op.handle());
220 if (it == handle_to_index_.end()) {
221 return InvalidArgument("No XlaOp with handle %d", op.handle());
222 }
223 return instruction_shapes_.at(it->second).get();
224 }
225
GetShape(XlaOp op) const226 StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
227 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
228 return *shape;
229 }
230
GetOperandShapes(absl::Span<const XlaOp> operands) const231 StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
232 absl::Span<const XlaOp> operands) const {
233 std::vector<Shape> operand_shapes;
234 operand_shapes.reserve(operands.size());
235 for (XlaOp operand : operands) {
236 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
237 operand_shapes.push_back(*shape);
238 }
239 return operand_shapes;
240 }
241
OpToString(XlaOp op) const242 std::string XlaBuilder::OpToString(XlaOp op) const {
243 std::string s;
244 ToStringHelper(&s, /*ident=*/0, op.handle());
245 return s;
246 }
247
ShapeToString(const ShapeProto & shape)248 static std::string ShapeToString(const ShapeProto& shape) {
249 if (shape.tuple_shapes_size() > 1) {
250 return absl::StrCat(
251 "(",
252 absl::StrJoin(shape.tuple_shapes(), ", ",
253 [&](std::string* s, const ShapeProto& subshape) {
254 absl::StrAppend(s, ShapeToString(subshape));
255 }),
256 ")");
257 }
258 return absl::StrCat("[", absl::StrJoin(shape.dimensions(), ", "), "]");
259 }
260
ToStringHelper(std::string * out,int ident,int64_t op_handle) const261 void XlaBuilder::ToStringHelper(std::string* out, int ident,
262 int64_t op_handle) const {
263 const HloInstructionProto& instr =
264 *(LookUpInstructionByHandle(op_handle).ValueOrDie());
265 absl::StrAppend(out, std::string(ident, ' '), instr.opcode(),
266 ", shape=", ShapeToString(instr.shape()));
267 if (instr.has_metadata()) {
268 absl::StrAppend(out, ", metadata={", instr.metadata().source_file(), ":",
269 instr.metadata().source_line(), "}");
270 }
271 if (instr.operand_ids_size()) {
272 absl::StrAppend(out, "\n");
273 }
274 absl::StrAppend(out, absl::StrJoin(instr.operand_ids(), "\n",
275 [&](std::string* s, int64_t subop) {
276 ToStringHelper(s, ident + 2, subop);
277 }));
278 }
279
XlaBuilder(const std::string & computation_name)280 XlaBuilder::XlaBuilder(const std::string& computation_name)
281 : name_(computation_name) {}
282
~XlaBuilder()283 XlaBuilder::~XlaBuilder() {}
284
ReportError(const Status & error)285 XlaOp XlaBuilder::ReportError(const Status& error) {
286 CHECK(!error.ok());
287 if (die_immediately_on_error_) {
288 LOG(FATAL) << "error building computation: " << error;
289 }
290
291 if (first_error_.ok()) {
292 first_error_ = error;
293 first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
294 }
295 return XlaOp(this);
296 }
297
ReportErrorOrReturn(const StatusOr<XlaOp> & op)298 XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
299 if (!first_error_.ok()) {
300 return XlaOp(this);
301 }
302 if (!op.ok()) {
303 return ReportError(op.status());
304 }
305 return op.ValueOrDie();
306 }
307
ReportErrorOrReturn(const std::function<StatusOr<XlaOp> ()> & op_creator)308 XlaOp XlaBuilder::ReportErrorOrReturn(
309 const std::function<StatusOr<XlaOp>()>& op_creator) {
310 return ReportErrorOrReturn(op_creator());
311 }
312
GetProgramShape(int64_t root_id) const313 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64_t root_id) const {
314 TF_RETURN_IF_ERROR(first_error_);
315 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
316 LookUpInstructionByHandle(root_id));
317
318 ProgramShape program_shape;
319
320 *program_shape.mutable_result() = Shape(root_proto->shape());
321
322 // Check that the parameter numbers are continuous from 0, and add parameter
323 // shapes and names to the program shape.
324 const int64_t param_count = parameter_numbers_.size();
325 for (int64_t i = 0; i < param_count; i++) {
326 program_shape.add_parameters();
327 program_shape.add_parameter_names();
328 }
329 for (const HloInstructionProto& instr : instructions_) {
330 // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
331 // to verify continuity, we just need to verify that every parameter is in
332 // the right range.
333 if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
334 const int64_t index = instr.parameter_number();
335 TF_RET_CHECK(index >= 0 && index < param_count)
336 << "invalid parameter number: " << index;
337 *program_shape.mutable_parameters(index) = Shape(instr.shape());
338 *program_shape.mutable_parameter_names(index) = instr.name();
339 }
340 }
341 return program_shape;
342 }
343
GetProgramShape() const344 StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
345 TF_RET_CHECK(!instructions_.empty());
346 return GetProgramShape(instructions_.back().id());
347 }
348
GetProgramShape(XlaOp root) const349 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
350 if (root.builder_ != this) {
351 return InvalidArgument("Given root operation is not in this computation.");
352 }
353 return GetProgramShape(root.handle());
354 }
355
IsConstantVisitor(const int64_t op_handle,int depth,absl::flat_hash_set<int64_t> * visited,bool * is_constant) const356 void XlaBuilder::IsConstantVisitor(const int64_t op_handle, int depth,
357 absl::flat_hash_set<int64_t>* visited,
358 bool* is_constant) const {
359 if (visited->contains(op_handle) || !*is_constant) {
360 return;
361 }
362
363 const HloInstructionProto& instr =
364 *(LookUpInstructionByHandle(op_handle).ValueOrDie());
365 HloInstructionProto to_print(instr);
366 to_print.clear_shape();
367 const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
368 const std::string indent =
369 absl::StrJoin(std::vector<absl::string_view>(depth, " "), "");
370 if (VLOG_IS_ON(2)) {
371 VLOG(2) << indent << "Visiting:";
372 for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
373 VLOG(2) << indent << l;
374 }
375 }
376 switch (opcode) {
377 default:
378 for (const int64_t operand_id : instr.operand_ids()) {
379 IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
380 }
381 // TODO(b/32495713): We aren't checking the called computations.
382 break;
383
384 case HloOpcode::kGetDimensionSize:
385 // GetDimensionSize is always considered constant in XLA -- If a dynamic
386 // dimension is presented, -1 is returned.
387 break;
388 // Non functional ops.
389 case HloOpcode::kRng:
390 case HloOpcode::kAllReduce:
391 case HloOpcode::kReduceScatter:
392 // TODO(b/33009255): Implement constant folding for cross replica sum.
393 case HloOpcode::kInfeed:
394 case HloOpcode::kOutfeed:
395 case HloOpcode::kCall:
396 // TODO(b/32495713): We aren't checking the to_apply computation itself,
397 // so we conservatively say that computations containing the Call op
398 // cannot be constant. We cannot set is_functional=false in other similar
399 // cases since we're already relying on IsConstant to return true.
400 case HloOpcode::kCustomCall:
401 if (instr.custom_call_target() == "SetBound") {
402 // Set bound is considered constant -- the bound is used as the value.
403 break;
404 }
405 [[fallthrough]];
406 case HloOpcode::kWhile:
407 // TODO(b/32495713): We aren't checking the condition and body
408 // computations themselves.
409 case HloOpcode::kScatter:
410 // TODO(b/32495713): We aren't checking the embedded computation in
411 // Scatter.
412 case HloOpcode::kSend:
413 case HloOpcode::kRecv:
414 case HloOpcode::kParameter:
415 *is_constant = false;
416 break;
417 case HloOpcode::kGetTupleElement: {
418 const HloInstructionProto& operand_instr =
419 *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
420 if (HloOpcodeString(HloOpcode::kTuple) == operand_instr.opcode()) {
421 IsConstantVisitor(operand_instr.operand_ids(instr.tuple_index()),
422 depth + 1, visited, is_constant);
423 } else {
424 for (const int64_t operand_id : instr.operand_ids()) {
425 IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
426 }
427 }
428 }
429 }
430 if (VLOG_IS_ON(1) && !*is_constant) {
431 VLOG(1) << indent << "Non-constant: ";
432 for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
433 VLOG(1) << indent << l;
434 }
435 }
436 visited->insert(op_handle);
437 }
438
SetDynamicBinding(int64_t dynamic_size_param_num,ShapeIndex dynamic_size_param_index,int64_t target_param_num,ShapeIndex target_param_index,int64_t target_dim_num)439 Status XlaBuilder::SetDynamicBinding(int64_t dynamic_size_param_num,
440 ShapeIndex dynamic_size_param_index,
441 int64_t target_param_num,
442 ShapeIndex target_param_index,
443 int64_t target_dim_num) {
444 bool param_exists = false;
445 for (size_t index = 0; index < instructions_.size(); ++index) {
446 HloInstructionProto& instr = instructions_[index];
447 if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
448 instr.parameter_number() == target_param_num) {
449 param_exists = true;
450 Shape param_shape(instr.shape());
451 Shape* param_shape_ptr = ¶m_shape;
452 for (int64_t index : target_param_index) {
453 param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
454 }
455 param_shape_ptr->set_dynamic_dimension(target_dim_num,
456 /*is_dynamic=*/true);
457 *instr.mutable_shape() = param_shape.ToProto();
458 instruction_shapes_[index] =
459 std::make_unique<Shape>(std::move(param_shape));
460 }
461 }
462 if (!param_exists) {
463 return InvalidArgument(
464 "Asked to mark parameter %lld as dynamic sized parameter, but the "
465 "doesn't exists",
466 target_param_num);
467 }
468
469 TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
470 DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
471 dynamic_size_param_index},
472 DynamicParameterBinding::DynamicDimension{
473 target_param_num, target_param_index, target_dim_num}));
474 return OkStatus();
475 }
476
SetInstructionFrontendAttribute(const XlaOp op,std::string attribute,std::string value)477 Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
478 std::string attribute,
479 std::string value) {
480 TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
481 auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
482 (*frontend_attributes->mutable_map())[attribute] = std::move(value);
483 return OkStatus();
484 }
485
BuildAndNoteError()486 XlaComputation XlaBuilder::BuildAndNoteError() {
487 DCHECK(parent_builder_ != nullptr);
488 auto build_status = Build();
489 if (!build_status.ok()) {
490 parent_builder_->ReportError(
491 AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
492 return {};
493 }
494 return std::move(build_status).value();
495 }
496
GetCurrentStatus() const497 Status XlaBuilder::GetCurrentStatus() const {
498 if (!first_error_.ok()) {
499 std::string backtrace;
500 first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
501 return AppendStatus(first_error_, backtrace);
502 }
503 return OkStatus();
504 }
505
Build(bool remove_dynamic_dimensions)506 StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
507 TF_RETURN_IF_ERROR(GetCurrentStatus());
508 return Build(instructions_.back().id(), remove_dynamic_dimensions);
509 }
510
Build(XlaOp root,bool remove_dynamic_dimensions)511 StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
512 bool remove_dynamic_dimensions) {
513 if (root.builder_ != this) {
514 return InvalidArgument("Given root operation is not in this computation.");
515 }
516 return Build(root.handle(), remove_dynamic_dimensions);
517 }
518
Build(int64_t root_id,bool remove_dynamic_dimensions)519 StatusOr<XlaComputation> XlaBuilder::Build(int64_t root_id,
520 bool remove_dynamic_dimensions) {
521 TF_RETURN_IF_ERROR(GetCurrentStatus());
522
523 // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove
524 // all dynamic dimensions before building xla program until we have support in
525 // the backend.
526 if (remove_dynamic_dimensions) {
527 std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
528 if (shape->tuple_shapes_size() != 0) {
529 for (int i = 0; i < shape->tuple_shapes_size(); ++i) {
530 remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
531 }
532 }
533 for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
534 shape->set_dynamic_dimension(i, false);
535 }
536 };
537 for (size_t index = 0; index < instructions_.size(); ++index) {
538 remove_dynamic_dimension(instruction_shapes_[index].get());
539 *instructions_[index].mutable_shape() =
540 instruction_shapes_[index]->ToProto();
541 }
542 }
543
544 HloComputationProto entry;
545 SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
546 TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
547 *entry.mutable_program_shape() = program_shape.ToProto();
548 entry.set_root_id(root_id);
549
550 for (auto& instruction : instructions_) {
551 // Ensures that the instruction names are unique among the whole graph.
552 instruction.set_name(
553 GetFullName(instruction.name(), kNameSeparator, instruction.id()));
554 entry.add_instructions()->Swap(&instruction);
555 }
556
557 XlaComputation computation(entry.id());
558 HloModuleProto* module = computation.mutable_proto();
559 module->set_name(entry.name());
560 module->set_id(entry.id());
561 module->set_entry_computation_name(entry.name());
562 module->set_entry_computation_id(entry.id());
563 *module->mutable_host_program_shape() = entry.program_shape();
564 for (auto& e : embedded_) {
565 module->add_computations()->Swap(&e.second);
566 }
567 module->add_computations()->Swap(&entry);
568 if (!input_output_aliases_.empty()) {
569 TF_RETURN_IF_ERROR(
570 PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
571 }
572 *(module->mutable_dynamic_parameter_binding()) =
573 dynamic_parameter_binding_.ToProto();
574
575 // Clear data held by this builder.
576 this->instructions_.clear();
577 this->instruction_shapes_.clear();
578 this->handle_to_index_.clear();
579 this->embedded_.clear();
580 this->parameter_numbers_.clear();
581
582 return std::move(computation);
583 }
584
PopulateInputOutputAlias(HloModuleProto * module,const ProgramShape & program_shape,const std::vector<InputOutputAlias> & input_output_aliases)585 /* static */ Status XlaBuilder::PopulateInputOutputAlias(
586 HloModuleProto* module, const ProgramShape& program_shape,
587 const std::vector<InputOutputAlias>& input_output_aliases) {
588 HloInputOutputAliasConfig config(program_shape.result());
589 for (auto& alias : input_output_aliases) {
590 // The HloInputOutputAliasConfig does not do parameter validation as it only
591 // carries the result shape. Maybe it should be constructed with a
592 // ProgramShape to allow full validation. We will still get an error when
593 // trying to compile the HLO module, but would be better to have validation
594 // at this stage.
595 if (alias.param_number >= program_shape.parameters_size()) {
596 return InvalidArgument("Invalid parameter number %ld (total %ld)",
597 alias.param_number,
598 program_shape.parameters_size());
599 }
600 const Shape& parameter_shape = program_shape.parameters(alias.param_number);
601 if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
602 return InvalidArgument("Invalid parameter %ld index: %s",
603 alias.param_number,
604 alias.param_index.ToString().c_str());
605 }
606 TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
607 alias.param_index, alias.kind));
608 }
609 *module->mutable_input_output_alias() = config.ToProto();
610 return OkStatus();
611 }
612
InDimBroadcast(const Shape & shape,XlaOp operand,absl::Span<const int64_t> broadcast_dimensions)613 StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
614 const Shape& shape, XlaOp operand,
615 absl::Span<const int64_t> broadcast_dimensions) {
616 TF_RETURN_IF_ERROR(first_error_);
617
618 HloInstructionProto instr;
619 *instr.mutable_shape() = shape.ToProto();
620 for (int64_t dim : broadcast_dimensions) {
621 instr.add_dimensions(dim);
622 }
623
624 return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
625 }
626
AddBroadcastSequence(const Shape & output_shape,XlaOp operand)627 StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
628 XlaOp operand) {
629 TF_RETURN_IF_ERROR(first_error_);
630
631 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
632
633 CHECK(ShapeUtil::IsScalar(*operand_shape) ||
634 operand_shape->rank() == output_shape.rank());
635 Shape broadcast_shape =
636 ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());
637
638 // Do explicit broadcast for scalar.
639 if (ShapeUtil::IsScalar(*operand_shape)) {
640 return InDimBroadcast(broadcast_shape, operand, {});
641 }
642
643 // Do explicit broadcast for degenerate broadcast.
644 std::vector<int64_t> broadcast_dimensions;
645 std::vector<int64_t> reshaped_dimensions;
646 for (int i = 0; i < operand_shape->rank(); i++) {
647 if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
648 broadcast_dimensions.push_back(i);
649 reshaped_dimensions.push_back(operand_shape->dimensions(i));
650 } else {
651 TF_RET_CHECK(operand_shape->dimensions(i) == 1)
652 << "An explicit broadcast sequence requires the broadcasted "
653 "dimensions to be trivial; operand shape: "
654 << *operand_shape << "; output_shape: " << output_shape;
655 }
656 }
657
658 Shape reshaped_shape =
659 ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);
660
661 std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
662 ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);
663
664 for (auto& unmodified : unmodified_dims) {
665 if (operand_shape->is_dynamic_dimension(unmodified.first)) {
666 reshaped_shape.set_dynamic_dimension(unmodified.second, true);
667 }
668 }
669
670 // Eliminate the size one dimensions.
671 TF_ASSIGN_OR_RETURN(
672 XlaOp reshaped_operand,
673 ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1));
674 // Broadcast 'reshape' up to the larger size.
675 return InDimBroadcast(broadcast_shape, reshaped_operand,
676 broadcast_dimensions);
677 }
678
UnaryOp(HloOpcode unop,XlaOp operand)679 XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
680 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
681 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
682 TF_ASSIGN_OR_RETURN(
683 Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
684 return AddOpWithShape(unop, shape, {operand});
685 });
686 }
687
BinaryOp(HloOpcode binop,XlaOp lhs,XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions,std::optional<ComparisonDirection> direction,std::optional<Comparison::Type> type)688 XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
689 absl::Span<const int64_t> broadcast_dimensions,
690 std::optional<ComparisonDirection> direction,
691 std::optional<Comparison::Type> type) {
692 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
693 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
694 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
695 TF_ASSIGN_OR_RETURN(
696 Shape shape, ShapeInference::InferBinaryOpShape(
697 binop, *lhs_shape, *rhs_shape, broadcast_dimensions));
698
699 const int64_t lhs_rank = lhs_shape->rank();
700 const int64_t rhs_rank = rhs_shape->rank();
701
702 XlaOp updated_lhs = lhs;
703 XlaOp updated_rhs = rhs;
704 if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
705 const bool should_broadcast_lhs = lhs_rank < rhs_rank;
706 XlaOp from = should_broadcast_lhs ? lhs : rhs;
707 const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;
708
709 std::vector<int64_t> to_size;
710 std::vector<bool> to_size_is_dynamic;
711 const auto rank = shape.rank();
712 to_size.reserve(rank);
713 to_size_is_dynamic.reserve(rank);
714 for (int i = 0; i < rank; i++) {
715 to_size.push_back(shape.dimensions(i));
716 to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
717 }
718 for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
719 int64_t to_dim = broadcast_dimensions[from_dim];
720 to_size[to_dim] = from_shape.dimensions(from_dim);
721 to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
722 }
723
724 const Shape& broadcasted_shape = ShapeUtil::MakeShape(
725 from_shape.element_type(), to_size, to_size_is_dynamic);
726 TF_ASSIGN_OR_RETURN(
727 XlaOp broadcasted_operand,
728 InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));
729
730 updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
731 updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
732 }
733
734 TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
735 GetShapePtr(updated_lhs));
736 if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
737 TF_ASSIGN_OR_RETURN(updated_lhs,
738 AddBroadcastSequence(shape, updated_lhs));
739 }
740 TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
741 GetShapePtr(updated_rhs));
742 if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
743 TF_ASSIGN_OR_RETURN(updated_rhs,
744 AddBroadcastSequence(shape, updated_rhs));
745 }
746
747 if (binop == HloOpcode::kCompare) {
748 if (!direction.has_value()) {
749 return InvalidArgument(
750 "kCompare expects a ComparisonDirection, but none provided.");
751 }
752 if (type == std::nullopt) {
753 return Compare(shape, updated_lhs, updated_rhs, *direction);
754 } else {
755 return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
756 }
757 }
758
759 if (direction.has_value()) {
760 return InvalidArgument(
761 "A comparison direction is provided for a non-compare opcode: %s.",
762 HloOpcodeString(binop));
763 }
764 return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
765 });
766 }
767
BinaryOpNoBroadcast(HloOpcode binop,const Shape & shape,XlaOp lhs,XlaOp rhs)768 XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
769 XlaOp lhs, XlaOp rhs) {
770 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
771 HloInstructionProto instr;
772 *instr.mutable_shape() = shape.ToProto();
773 return AddInstruction(std::move(instr), binop, {lhs, rhs});
774 });
775 }
776
Compare(const Shape & shape,XlaOp lhs,XlaOp rhs,ComparisonDirection direction)777 StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
778 ComparisonDirection direction) {
779 TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
780 return Compare(
781 shape, lhs, rhs, direction,
782 Comparison::DefaultComparisonType(operand_shape.element_type()));
783 }
784
Compare(const Shape & shape,XlaOp lhs,XlaOp rhs,ComparisonDirection direction,Comparison::Type type)785 StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
786 ComparisonDirection direction,
787 Comparison::Type type) {
788 HloInstructionProto instr;
789 instr.set_comparison_direction(ComparisonDirectionToString(direction));
790 instr.set_comparison_type(ComparisonTypeToString(type));
791 *instr.mutable_shape() = shape.ToProto();
792 return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
793 }
794
TernaryOp(HloOpcode triop,XlaOp lhs,XlaOp rhs,XlaOp ehs)795 XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
796 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
797 XlaOp updated_lhs = lhs;
798 XlaOp updated_rhs = rhs;
799 XlaOp updated_ehs = ehs;
800 // The client API supports implicit broadcast for kSelect and kClamp, but
801 // XLA does not support implicit broadcast. Make implicit broadcast explicit
802 // and update the operands.
803 if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
804 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
805 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
806 TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));
807
808 std::optional<Shape> non_scalar_shape;
809 for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
810 if (shape->IsArray() && shape->rank() != 0) {
811 if (non_scalar_shape.has_value()) {
812 // TODO(jpienaar): The case where we need to compute the broadcasted
813 // shape by considering multiple of the shapes is not implemented.
814 // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
815 TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
816 shape->dimensions())
817 << "Unimplemented implicit broadcast.";
818 } else {
819 non_scalar_shape = *shape;
820 }
821 }
822 }
823 if (non_scalar_shape.has_value()) {
824 if (ShapeUtil::IsScalar(*lhs_shape)) {
825 TF_ASSIGN_OR_RETURN(updated_lhs,
826 AddBroadcastSequence(*non_scalar_shape, lhs));
827 }
828 if (ShapeUtil::IsScalar(*rhs_shape)) {
829 TF_ASSIGN_OR_RETURN(updated_rhs,
830 AddBroadcastSequence(*non_scalar_shape, rhs));
831 }
832 if (ShapeUtil::IsScalar(*ehs_shape)) {
833 TF_ASSIGN_OR_RETURN(updated_ehs,
834 AddBroadcastSequence(*non_scalar_shape, ehs));
835 }
836 }
837 }
838
839 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
840 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
841 TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
842 StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
843 triop, *lhs_shape, *rhs_shape, *ehs_shape);
844 if (!status_or_shape.status().ok()) {
845 return InvalidArgument(
846 "%s Input scalar shapes may have been changed to non-scalar shapes.",
847 status_or_shape.status().error_message());
848 }
849
850 return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
851 {updated_lhs, updated_rhs, updated_ehs});
852 });
853 }
854
ConstantLiteral(const LiteralSlice & literal)855 XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
856 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
857 if (literal.shape().IsArray() && literal.element_count() > 1 &&
858 literal.IsAllFirst()) {
859 Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
860 HloInstructionProto instr;
861 *instr.mutable_shape() = scalar.shape().ToProto();
862 *instr.mutable_literal() = scalar.ToProto();
863 TF_ASSIGN_OR_RETURN(
864 XlaOp scalar_op,
865 AddInstruction(std::move(instr), HloOpcode::kConstant));
866 return Broadcast(scalar_op, literal.shape().dimensions());
867 } else {
868 HloInstructionProto instr;
869 *instr.mutable_shape() = literal.shape().ToProto();
870 *instr.mutable_literal() = literal.ToProto();
871 return AddInstruction(std::move(instr), HloOpcode::kConstant);
872 }
873 });
874 }
875
Iota(const Shape & shape,int64_t iota_dimension)876 XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
877 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
878 HloInstructionProto instr;
879 *instr.mutable_shape() = shape.ToProto();
880 instr.add_dimensions(iota_dimension);
881 return AddInstruction(std::move(instr), HloOpcode::kIota);
882 });
883 }
884
Iota(PrimitiveType type,int64_t size)885 XlaOp XlaBuilder::Iota(PrimitiveType type, int64_t size) {
886 return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
887 }
888
Call(const XlaComputation & computation,absl::Span<const XlaOp> operands)889 XlaOp XlaBuilder::Call(const XlaComputation& computation,
890 absl::Span<const XlaOp> operands) {
891 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
892 HloInstructionProto instr;
893 std::vector<const Shape*> operand_shape_ptrs;
894 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
895 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
896 [](const Shape& shape) { return &shape; });
897 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
898 computation.GetProgramShape());
899 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
900 operand_shape_ptrs,
901 /*to_apply=*/called_program_shape));
902 *instr.mutable_shape() = shape.ToProto();
903
904 AddCalledComputation(computation, &instr);
905
906 return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
907 });
908 }
909
Parameter(int64_t parameter_number,const Shape & shape,const std::string & name,const std::vector<bool> & replicated_at_leaf_buffers)910 XlaOp XlaBuilder::Parameter(
911 int64_t parameter_number, const Shape& shape, const std::string& name,
912 const std::vector<bool>& replicated_at_leaf_buffers) {
913 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
914 HloInstructionProto instr;
915 if (!parameter_numbers_.insert(parameter_number).second) {
916 return InvalidArgument("parameter %d already registered",
917 parameter_number);
918 }
919 instr.set_parameter_number(parameter_number);
920 instr.set_name(name);
921 *instr.mutable_shape() = shape.ToProto();
922 if (!replicated_at_leaf_buffers.empty()) {
923 auto replication = instr.mutable_parameter_replication();
924 for (bool replicated : replicated_at_leaf_buffers) {
925 replication->add_replicated_at_leaf_buffers(replicated);
926 }
927 }
928 return AddInstruction(std::move(instr), HloOpcode::kParameter);
929 });
930 }
931
Broadcast(XlaOp operand,absl::Span<const int64_t> broadcast_sizes)932 XlaOp XlaBuilder::Broadcast(XlaOp operand,
933 absl::Span<const int64_t> broadcast_sizes) {
934 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
935 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
936 TF_ASSIGN_OR_RETURN(
937 const Shape& shape,
938 ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));
939
940 // The client-level broadcast op just appends dimensions on the left (adds
941 // lowest numbered dimensions). The HLO broadcast instruction is more
942 // flexible and can add new dimensions anywhere. The instruction's
943 // dimensions field maps operand dimensions to dimensions in the broadcast
944 // output, so to append dimensions on the left the instruction's dimensions
945 // should just be the n highest dimension numbers of the output shape where
946 // n is the number of input dimensions.
947 const int64_t operand_rank = operand_shape->rank();
948 std::vector<int64_t> dimensions(operand_rank);
949 for (int i = 0; i < operand_rank; ++i) {
950 dimensions[i] = i + shape.rank() - operand_rank;
951 }
952 return InDimBroadcast(shape, operand, dimensions);
953 });
954 }
955
BroadcastInDim(XlaOp operand,const absl::Span<const int64_t> out_dim_size,const absl::Span<const int64_t> broadcast_dimensions)956 XlaOp XlaBuilder::BroadcastInDim(
957 XlaOp operand, const absl::Span<const int64_t> out_dim_size,
958 const absl::Span<const int64_t> broadcast_dimensions) {
959 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
960 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
961 // Output shape, in the case of degenerate broadcast, the out_dim_size is
962 // not necessarily the same as the dimension sizes of the output shape.
963 TF_ASSIGN_OR_RETURN(auto output_shape,
964 ShapeUtil::MakeValidatedShape(
965 operand_shape->element_type(), out_dim_size));
966 int64_t broadcast_rank = broadcast_dimensions.size();
967 if (operand_shape->rank() != broadcast_rank) {
968 return InvalidArgument(
969 "Size of broadcast_dimensions has to match operand's rank; operand "
970 "rank: %lld, size of broadcast_dimensions %u.",
971 operand_shape->rank(), broadcast_dimensions.size());
972 }
973 for (int i = 0; i < broadcast_rank; i++) {
974 const int64_t num_dims = out_dim_size.size();
975 if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > num_dims) {
976 return InvalidArgument("Broadcast dimension %lld is out of bound",
977 broadcast_dimensions[i]);
978 }
979 output_shape.set_dynamic_dimension(
980 broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
981 }
982
983 TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
984 *operand_shape, output_shape, broadcast_dimensions)
985 .status());
986 std::vector<int64_t> in_dim_size(out_dim_size.begin(), out_dim_size.end());
987 for (int i = 0; i < broadcast_rank; i++) {
988 in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
989 }
990 const auto& in_dim_shape =
991 ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
992 TF_ASSIGN_OR_RETURN(
993 XlaOp in_dim_broadcast,
994 InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));
995
996 // If broadcast is not degenerate, return broadcasted result.
997 if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
998 return in_dim_broadcast;
999 }
1000
1001 // Otherwise handle degenerate broadcast case.
1002 return AddBroadcastSequence(output_shape, in_dim_broadcast);
1003 });
1004 }
1005
ReshapeInternal(const Shape & shape,XlaOp operand,int64_t inferred_dimension)1006 StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
1007 int64_t inferred_dimension) {
1008 TF_RETURN_IF_ERROR(first_error_);
1009
1010 HloInstructionProto instr;
1011 *instr.mutable_shape() = shape.ToProto();
1012 if (inferred_dimension != -1) {
1013 instr.add_dimensions(inferred_dimension);
1014 }
1015 return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
1016 }
1017
Slice(XlaOp operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)1018 XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
1019 absl::Span<const int64_t> limit_indices,
1020 absl::Span<const int64_t> strides) {
1021 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1022 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1023 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
1024 *operand_shape, start_indices,
1025 limit_indices, strides));
1026 return SliceInternal(shape, operand, start_indices, limit_indices, strides);
1027 });
1028 }
1029
SliceInternal(const Shape & shape,XlaOp operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)1030 StatusOr<XlaOp> XlaBuilder::SliceInternal(
1031 const Shape& shape, XlaOp operand, absl::Span<const int64_t> start_indices,
1032 absl::Span<const int64_t> limit_indices,
1033 absl::Span<const int64_t> strides) {
1034 HloInstructionProto instr;
1035 *instr.mutable_shape() = shape.ToProto();
1036 for (int i = 0, end = start_indices.size(); i < end; i++) {
1037 auto* slice_config = instr.add_slice_dimensions();
1038 slice_config->set_start(start_indices[i]);
1039 slice_config->set_limit(limit_indices[i]);
1040 slice_config->set_stride(strides[i]);
1041 }
1042 return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
1043 }
1044
SliceInDim(XlaOp operand,int64_t start_index,int64_t limit_index,int64_t stride,int64_t dimno)1045 XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index,
1046 int64_t limit_index, int64_t stride,
1047 int64_t dimno) {
1048 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1049 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
1050 std::vector<int64_t> starts(shape->rank(), 0);
1051 std::vector<int64_t> limits(shape->dimensions().begin(),
1052 shape->dimensions().end());
1053 std::vector<int64_t> strides(shape->rank(), 1);
1054 starts[dimno] = start_index;
1055 limits[dimno] = limit_index;
1056 strides[dimno] = stride;
1057 return Slice(operand, starts, limits, strides);
1058 });
1059 }
1060
DynamicSlice(XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64_t> slice_sizes)1061 XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
1062 absl::Span<const XlaOp> start_indices,
1063 absl::Span<const int64_t> slice_sizes) {
1064 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1065 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1066 std::vector<const Shape*> start_indices_shape_ptrs;
1067 TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
1068 GetOperandShapes(start_indices));
1069 absl::c_transform(start_indices_shapes,
1070 std::back_inserter(start_indices_shape_ptrs),
1071 [](const Shape& shape) { return &shape; });
1072 TF_ASSIGN_OR_RETURN(Shape shape,
1073 ShapeInference::InferDynamicSliceShape(
1074 *operand_shape, start_indices_shapes, slice_sizes));
1075 return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
1076 });
1077 }
1078
DynamicSliceInternal(const Shape & shape,XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64_t> slice_sizes)1079 StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
1080 const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
1081 absl::Span<const int64_t> slice_sizes) {
1082 HloInstructionProto instr;
1083 *instr.mutable_shape() = shape.ToProto();
1084
1085 for (int64_t size : slice_sizes) {
1086 instr.add_dynamic_slice_sizes(size);
1087 }
1088
1089 std::vector<XlaOp> operands = {operand};
1090 operands.insert(operands.end(), start_indices.begin(), start_indices.end());
1091 return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
1092 }
1093
DynamicUpdateSlice(XlaOp operand,XlaOp update,absl::Span<const XlaOp> start_indices)1094 XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
1095 absl::Span<const XlaOp> start_indices) {
1096 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1097 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1098 TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
1099 std::vector<const Shape*> start_indices_shape_ptrs;
1100 TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
1101 GetOperandShapes(start_indices));
1102 absl::c_transform(start_indices_shapes,
1103 std::back_inserter(start_indices_shape_ptrs),
1104 [](const Shape& shape) { return &shape; });
1105 TF_ASSIGN_OR_RETURN(
1106 Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
1107 *operand_shape, *update_shape, start_indices_shapes));
1108 return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
1109 });
1110 }
1111
DynamicUpdateSliceInternal(const Shape & shape,XlaOp operand,XlaOp update,absl::Span<const XlaOp> start_indices)1112 StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
1113 const Shape& shape, XlaOp operand, XlaOp update,
1114 absl::Span<const XlaOp> start_indices) {
1115 HloInstructionProto instr;
1116 *instr.mutable_shape() = shape.ToProto();
1117
1118 std::vector<XlaOp> operands = {operand, update};
1119 operands.insert(operands.end(), start_indices.begin(), start_indices.end());
1120 return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
1121 operands);
1122 }
1123
ConcatInDim(absl::Span<const XlaOp> operands,int64_t dimension)1124 XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
1125 int64_t dimension) {
1126 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1127 std::vector<const Shape*> operand_shape_ptrs;
1128 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
1129 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1130 [](const Shape& shape) { return &shape; });
1131 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
1132 operand_shape_ptrs, dimension));
1133 return ConcatInDimInternal(shape, operands, dimension);
1134 });
1135 }
1136
ConcatInDimInternal(const Shape & shape,absl::Span<const XlaOp> operands,int64_t dimension)1137 StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
1138 const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
1139 HloInstructionProto instr;
1140 *instr.mutable_shape() = shape.ToProto();
1141
1142 instr.add_dimensions(dimension);
1143
1144 return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
1145 }
1146
Pad(XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)1147 XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
1148 const PaddingConfig& padding_config) {
1149 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1150 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1151 TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
1152 GetShapePtr(padding_value));
1153 TF_ASSIGN_OR_RETURN(
1154 Shape shape, ShapeInference::InferPadShape(
1155 *operand_shape, *padding_value_shape, padding_config));
1156 return PadInternal(shape, operand, padding_value, padding_config);
1157 });
1158 }
1159
PadInDim(XlaOp operand,XlaOp padding_value,int64_t dimno,int64_t pad_lo,int64_t pad_hi)1160 XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
1161 int64_t pad_lo, int64_t pad_hi) {
1162 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1163 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
1164 PaddingConfig padding_config = MakeNoPaddingConfig(shape->rank());
1165 auto* dims = padding_config.mutable_dimensions(dimno);
1166 dims->set_edge_padding_low(pad_lo);
1167 dims->set_edge_padding_high(pad_hi);
1168 return Pad(operand, padding_value, padding_config);
1169 });
1170 }
1171
PadInternal(const Shape & shape,XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)1172 StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
1173 XlaOp padding_value,
1174 const PaddingConfig& padding_config) {
1175 HloInstructionProto instr;
1176 *instr.mutable_shape() = shape.ToProto();
1177 *instr.mutable_padding_config() = padding_config;
1178 return AddInstruction(std::move(instr), HloOpcode::kPad,
1179 {operand, padding_value});
1180 }
1181
Reshape(XlaOp operand,absl::Span<const int64_t> dimensions,absl::Span<const int64_t> new_sizes,int64_t inferred_dimension)1182 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
1183 absl::Span<const int64_t> new_sizes,
1184 int64_t inferred_dimension) {
1185 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1186 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1187 TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
1188 *operand_shape, dimensions,
1189 new_sizes, inferred_dimension));
1190 XlaOp transposed = IsIdentityPermutation(dimensions)
1191 ? operand
1192 : Transpose(operand, dimensions);
1193 return ReshapeInternal(shape, transposed, inferred_dimension);
1194 });
1195 }
1196
Reshape(XlaOp operand,absl::Span<const int64_t> new_sizes,int64_t inferred_dimension)1197 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes,
1198 int64_t inferred_dimension) {
1199 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1200 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
1201 std::vector<int64_t> dimensions(shape->dimensions_size());
1202 std::iota(dimensions.begin(), dimensions.end(), 0);
1203 return Reshape(operand, dimensions, new_sizes, inferred_dimension);
1204 });
1205 }
1206
Reshape(const Shape & shape,XlaOp operand,int64_t inferred_dimension)1207 XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
1208 int64_t inferred_dimension) {
1209 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1210 return ReshapeInternal(shape, operand, inferred_dimension);
1211 });
1212 }
1213
DynamicReshape(XlaOp operand,absl::Span<const XlaOp> dim_sizes,absl::Span<const int64_t> new_size_bounds,const std::vector<bool> & dims_are_dynamic)1214 XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
1215 absl::Span<const XlaOp> dim_sizes,
1216 absl::Span<const int64_t> new_size_bounds,
1217 const std::vector<bool>& dims_are_dynamic) {
1218 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1219 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1220 std::vector<const Shape*> dim_size_shape_ptrs;
1221 TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
1222 GetOperandShapes(dim_sizes));
1223
1224 absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
1225 [](const Shape& shape) { return &shape; });
1226 TF_ASSIGN_OR_RETURN(const Shape shape,
1227 ShapeInference::InferDynamicReshapeShape(
1228 *operand_shape, dim_size_shape_ptrs,
1229 new_size_bounds, dims_are_dynamic));
1230 TF_RETURN_IF_ERROR(first_error_);
1231 std::vector<XlaOp> operands;
1232 operands.reserve(1 + dim_sizes.size());
1233 operands.push_back(operand);
1234 for (const XlaOp& dim_size : dim_sizes) {
1235 operands.push_back(dim_size);
1236 }
1237 HloInstructionProto instr;
1238 *instr.mutable_shape() = shape.ToProto();
1239 return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
1240 operands);
1241 });
1242 }
1243
Collapse(XlaOp operand,absl::Span<const int64_t> dimensions)1244 XlaOp XlaBuilder::Collapse(XlaOp operand,
1245 absl::Span<const int64_t> dimensions) {
1246 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1247 if (dimensions.size() <= 1) {
1248 // Not collapsing anything, trivially we can return the operand versus
1249 // enqueueing a trivial reshape.
1250 return operand;
1251 }
1252
1253 // Out-of-order collapse is not supported.
1254 // Checks that the collapsed dimensions are in order and consecutive.
1255 for (absl::Span<const int64_t>::size_type i = 1; i < dimensions.size();
1256 ++i) {
1257 if (dimensions[i] - 1 != dimensions[i - 1]) {
1258 return InvalidArgument(
1259 "Collapsed dimensions are not in consecutive order.");
1260 }
1261 }
1262
1263 // Create a new sizes vector from the old shape, replacing the collapsed
1264 // dimensions by the product of their sizes.
1265 TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));
1266
1267 VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
1268 VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
1269
1270 std::vector<int64_t> new_sizes;
1271 for (int i = 0; i < original_shape->rank(); ++i) {
1272 if (i <= dimensions.front() || i > dimensions.back()) {
1273 new_sizes.push_back(original_shape->dimensions(i));
1274 } else {
1275 new_sizes.back() *= original_shape->dimensions(i);
1276 }
1277 }
1278
1279 VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";
1280
1281 return Reshape(operand, new_sizes);
1282 });
1283 }
1284
1285 // Dummy pass-through computation returning it's parameter of shape `shape`.
PassthroughComputation(const Shape & shape)1286 static StatusOr<XlaComputation> PassthroughComputation(const Shape& shape) {
1287 XlaBuilder builder("dummy");
1288 XlaOp out = Parameter(&builder, 0, shape, "p");
1289 return builder.Build(out);
1290 }
1291
Select(XlaOp pred,XlaOp on_true,XlaOp on_false)1292 XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
1293 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1294 TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
1295 TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
1296 TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
1297 if (true_shape->IsTuple()) {
1298 TF_ASSIGN_OR_RETURN(XlaComputation passthrough_true,
1299 PassthroughComputation(*true_shape));
1300 TF_ASSIGN_OR_RETURN(XlaComputation passthrough_false,
1301 PassthroughComputation(*false_shape));
1302 return Conditional(pred, on_true, passthrough_true, on_false,
1303 passthrough_false);
1304 }
1305 return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false);
1306 });
1307 }
1308
Tuple(absl::Span<const XlaOp> elements)1309 XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
1310 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1311 std::vector<const Shape*> operand_shape_ptrs;
1312 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
1313 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1314 [](const Shape& shape) { return &shape; });
1315 TF_ASSIGN_OR_RETURN(const Shape shape,
1316 ShapeInference::InferVariadicOpShape(
1317 HloOpcode::kTuple, operand_shape_ptrs));
1318 return TupleInternal(shape, elements);
1319 });
1320 }
1321
1322 StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
1323 absl::Span<const XlaOp> elements) {
1324 HloInstructionProto instr;
1325 *instr.mutable_shape() = shape.ToProto();
1326 return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
1327 }
1328
1329 XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64_t index) {
1330 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1331 TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
1332 if (!tuple_shape->IsTuple()) {
1333 return InvalidArgument(
1334 "Operand to GetTupleElement() is not a tuple; got %s",
1335 ShapeUtil::HumanString(*tuple_shape));
1336 }
1337 if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
1338 return InvalidArgument(
1339 "GetTupleElement() index (%d) out of range for tuple shape %s", index,
1340 ShapeUtil::HumanString(*tuple_shape));
1341 }
1342 return GetTupleElementInternal(
1343 ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
1344 index);
1345 });
1346 }
1347
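// Rough usage sketch (b, x and y are an assumed XlaBuilder and two ops built on
// it): packing values into a tuple and reading one element back out.
//
//   XlaOp t = Tuple(&b, {x, y});           // shape: (shape(x), shape(y))
//   XlaOp y2 = GetTupleElement(t, 1);      // same shape as y
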
1348 StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
1349 XlaOp tuple_data,
1350 int64_t index) {
1351 HloInstructionProto instr;
1352 *instr.mutable_shape() = shape.ToProto();
1353 instr.set_tuple_index(index);
1354 return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
1355 {tuple_data});
1356 }
1357
1358 XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
1359 const PrecisionConfig* precision_config,
1360 std::optional<PrimitiveType> preferred_element_type) {
1361 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1362 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1363
1364 DotDimensionNumbers dimension_numbers;
1365 dimension_numbers.add_lhs_contracting_dimensions(
1366 lhs_shape->dimensions_size() == 1 ? 0 : 1);
1367 dimension_numbers.add_rhs_contracting_dimensions(0);
1368 return DotGeneral(lhs, rhs, dimension_numbers, precision_config,
1369 preferred_element_type);
1370 });
1371 }
1372
1373 XlaOp XlaBuilder::DotGeneral(
1374 XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
1375 const PrecisionConfig* precision_config,
1376 std::optional<PrimitiveType> preferred_element_type) {
1377 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1378 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1379 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1380 TF_ASSIGN_OR_RETURN(
1381 Shape shape,
1382 ShapeInference::InferDotOpShape(
1383 *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
1384 return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
1385 precision_config);
1386 });
1387 }
1388
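// Rough usage sketch (lhs of shape f32[B,M,K] and rhs of shape f32[B,K,N] are
// assumed operands): a batched matrix multiply expressed via DotDimensionNumbers.
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   XlaOp out = DotGeneral(lhs, rhs, dnums);  // result shape: f32[B, M, N]
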
1389 StatusOr<XlaOp> XlaBuilder::DotGeneralInternal(
1390 const Shape& shape, XlaOp lhs, XlaOp rhs,
1391 const DotDimensionNumbers& dimension_numbers,
1392 const PrecisionConfig* precision_config) {
1393 HloInstructionProto instr;
1394 *instr.mutable_shape() = shape.ToProto();
1395 *instr.mutable_dot_dimension_numbers() = dimension_numbers;
1396 if (precision_config != nullptr) {
1397 *instr.mutable_precision_config() = *precision_config;
1398 }
1399 return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
1400 }
1401
1402 Status XlaBuilder::VerifyConvolution(
1403 const Shape& lhs_shape, const Shape& rhs_shape,
1404 const ConvolutionDimensionNumbers& dimension_numbers) const {
1405 if (lhs_shape.rank() != rhs_shape.rank()) {
1406 return InvalidArgument(
1407 "Convolution arguments must have same number of "
1408 "dimensions. Got: %s and %s",
1409 ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1410 }
1411 int num_dims = lhs_shape.rank();
1412 if (num_dims < 2) {
1413 return InvalidArgument(
1414 "Convolution expects argument arrays with >= 3 dimensions. "
1415 "Got: %s and %s",
1416 ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1417 }
1418 int num_spatial_dims = num_dims - 2;
1419
1420 const auto check_spatial_dimensions = [&](absl::string_view field_name,
1421 absl::Span<const int64_t> numbers) {
1422 if (numbers.size() != num_spatial_dims) {
1423 return InvalidArgument("Expected %d elements for %s, but got %d.",
1424 num_spatial_dims, field_name, numbers.size());
1425 }
1426 for (int i = 0; i < numbers.size(); ++i) {
1427 if (numbers[i] < 0 || numbers[i] >= num_dims) {
1428 return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
1429 field_name, i, numbers[i]);
1430 }
1431 }
1432 return OkStatus();
1433 };
1434 TF_RETURN_IF_ERROR(
1435 check_spatial_dimensions("input_spatial_dimensions",
1436 dimension_numbers.input_spatial_dimensions()));
1437 TF_RETURN_IF_ERROR(
1438 check_spatial_dimensions("kernel_spatial_dimensions",
1439 dimension_numbers.kernel_spatial_dimensions()));
1440 return check_spatial_dimensions(
1441 "output_spatial_dimensions",
1442 dimension_numbers.output_spatial_dimensions());
1443 }
1444
1445 XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
1446 absl::Span<const int64_t> window_strides,
1447 Padding padding, int64_t feature_group_count,
1448 int64_t batch_group_count,
1449 const PrecisionConfig* precision_config,
1450 std::optional<PrimitiveType> preferred_element_type) {
1451 return ConvWithGeneralDimensions(
1452 lhs, rhs, window_strides, padding,
1453 CreateDefaultConvDimensionNumbers(window_strides.size()),
1454 feature_group_count, batch_group_count, precision_config,
1455 preferred_element_type);
1456 }
1457
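// Rough usage sketch: with the default dimension numbers the input is laid out
// N x C x H x W and the kernel O x I x H x W; the shapes and names below are
// assumptions for illustration only.
//
//   XlaOp in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 3, 28, 28}), "in");
//   XlaOp k = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 3, 3, 3}), "k");
//   XlaOp out = Conv(in, k, /*window_strides=*/{1, 1}, Padding::kSame);
//   // result shape: f32[8, 16, 28, 28]
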
1458 XlaOp XlaBuilder::ConvWithGeneralPadding(
1459 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1460 absl::Span<const std::pair<int64_t, int64_t>> padding,
1461 int64_t feature_group_count, int64_t batch_group_count,
1462 const PrecisionConfig* precision_config,
1463 std::optional<PrimitiveType> preferred_element_type) {
1464 return ConvGeneral(lhs, rhs, window_strides, padding,
1465 CreateDefaultConvDimensionNumbers(window_strides.size()),
1466 feature_group_count, batch_group_count, precision_config,
1467 preferred_element_type);
1468 }
1469
1470 XlaOp XlaBuilder::ConvWithGeneralDimensions(
1471 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1472 Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1473 int64_t feature_group_count, int64_t batch_group_count,
1474 const PrecisionConfig* precision_config,
1475 std::optional<PrimitiveType> preferred_element_type) {
1476 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1477 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1478 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1479
1480 TF_RETURN_IF_ERROR(
1481 VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1482
1483 std::vector<int64_t> base_area_dimensions(
1484 dimension_numbers.input_spatial_dimensions_size());
1485 for (std::vector<int64_t>::size_type i = 0; i < base_area_dimensions.size();
1486 ++i) {
1487 base_area_dimensions[i] =
1488 lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
1489 }
1490
1491 std::vector<int64_t> window_dimensions(
1492 dimension_numbers.kernel_spatial_dimensions_size());
1493 for (std::vector<int64_t>::size_type i = 0; i < window_dimensions.size();
1494 ++i) {
1495 window_dimensions[i] =
1496 rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1497 }
1498
1499 return ConvGeneral(lhs, rhs, window_strides,
1500 MakePadding(base_area_dimensions, window_dimensions,
1501 window_strides, padding),
1502 dimension_numbers, feature_group_count,
1503 batch_group_count, precision_config,
1504 preferred_element_type);
1505 });
1506 }
1507
1508 XlaOp XlaBuilder::ConvGeneral(
1509 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1510 absl::Span<const std::pair<int64_t, int64_t>> padding,
1511 const ConvolutionDimensionNumbers& dimension_numbers,
1512 int64_t feature_group_count, int64_t batch_group_count,
1513 const PrecisionConfig* precision_config,
1514 std::optional<PrimitiveType> preferred_element_type) {
1515 return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
1516 dimension_numbers, feature_group_count,
1517 batch_group_count, precision_config,
1518 preferred_element_type);
1519 }
1520
1521 // TODO(rmcilroy) Ideally window_reversal would be an absl::Span<const bool>;
1522 // however, this causes an error in pybind11's conversion code. See
1523 // https://github.com/pybind/pybind11_abseil/issues/4 for details.
1524 XlaOp XlaBuilder::ConvGeneralDilated(
1525 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1526 absl::Span<const std::pair<int64_t, int64_t>> padding,
1527 absl::Span<const int64_t> lhs_dilation,
1528 absl::Span<const int64_t> rhs_dilation,
1529 const ConvolutionDimensionNumbers& dimension_numbers,
1530 int64_t feature_group_count, int64_t batch_group_count,
1531 const PrecisionConfig* precision_config,
1532 std::optional<PrimitiveType> preferred_element_type,
1533 std::optional<std::vector<bool>> window_reversal) {
1534 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1535 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1536 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1537 TF_RETURN_IF_ERROR(
1538 VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1539
1540 std::vector<int64_t> window_dimensions(
1541 dimension_numbers.kernel_spatial_dimensions_size());
1542 for (std::vector<int64_t>::size_type i = 0; i < window_dimensions.size();
1543 ++i) {
1544 window_dimensions[i] =
1545 rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1546 }
1547
1548 TF_ASSIGN_OR_RETURN(Window window,
1549 ShapeInference::InferWindowFromDimensions(
1550 window_dimensions, window_strides, padding,
1551 lhs_dilation, rhs_dilation, window_reversal));
1552 TF_ASSIGN_OR_RETURN(
1553 Shape shape,
1554 ShapeInference::InferConvolveShape(
1555 *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
1556 window, dimension_numbers, preferred_element_type));
1557 return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
1558 padding, lhs_dilation, rhs_dilation,
1559 dimension_numbers, feature_group_count,
1560 batch_group_count, precision_config);
1561 });
1562 }
1563
1564 StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
1565 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1566 absl::Span<const std::pair<int64_t, int64_t>> padding,
1567 absl::Span<const int64_t> lhs_dilation,
1568 absl::Span<const int64_t> rhs_dilation,
1569 const ConvolutionDimensionNumbers& dimension_numbers,
1570 int64_t feature_group_count, int64_t batch_group_count,
1571 const PrecisionConfig* precision_config, PaddingType padding_type,
1572 std::optional<PrimitiveType> preferred_element_type) {
1573 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1574 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1575 std::vector<int64_t> window_dimensions(
1576 dimension_numbers.kernel_spatial_dimensions_size());
1577 for (std::vector<int64_t>::size_type i = 0; i < window_dimensions.size();
1578 ++i) {
1579 window_dimensions[i] =
1580 rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1581 }
1582
1583 TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
1584 window_dimensions, window_strides,
1585 padding, lhs_dilation, rhs_dilation));
1586 TF_ASSIGN_OR_RETURN(
1587 Shape shape,
1588 ShapeInference::InferConvolveShape(
1589 *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
1590 window, dimension_numbers, preferred_element_type));
1591
1592 HloInstructionProto instr;
1593 *instr.mutable_shape() = shape.ToProto();
1594
1595 *instr.mutable_window() = window;
1596 *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
1597 instr.set_feature_group_count(feature_group_count);
1598 instr.set_batch_group_count(batch_group_count);
1599 instr.set_padding_type(padding_type);
1600
1601 if (precision_config != nullptr) {
1602 *instr.mutable_precision_config() = *precision_config;
1603 }
1604 return std::move(instr);
1605 }
1606
1607 XlaOp XlaBuilder::DynamicConvInputGrad(
1608 XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
1609 absl::Span<const int64_t> window_strides,
1610 absl::Span<const std::pair<int64_t, int64_t>> padding,
1611 absl::Span<const int64_t> lhs_dilation,
1612 absl::Span<const int64_t> rhs_dilation,
1613 const ConvolutionDimensionNumbers& dimension_numbers,
1614 int64_t feature_group_count, int64_t batch_group_count,
1615 const PrecisionConfig* precision_config, PaddingType padding_type,
1616 std::optional<PrimitiveType> preferred_element_type) {
1617 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1618 TF_ASSIGN_OR_RETURN(
1619 HloInstructionProto instr,
1620 DynamicConvInstruction(
1621 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
1622 dimension_numbers, feature_group_count, batch_group_count,
1623 precision_config, padding_type, preferred_element_type));
1624
1625 instr.set_custom_call_target("DynamicConvolutionInputGrad");
1626
1627 return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
1628 {input_sizes, lhs, rhs});
1629 });
1630 }
1631
1632 XlaOp XlaBuilder::DynamicConvKernelGrad(
1633 XlaOp activations, XlaOp gradients,
1634 absl::Span<const int64_t> window_strides,
1635 absl::Span<const std::pair<int64_t, int64_t>> padding,
1636 absl::Span<const int64_t> lhs_dilation,
1637 absl::Span<const int64_t> rhs_dilation,
1638 const ConvolutionDimensionNumbers& dimension_numbers,
1639 int64_t feature_group_count, int64_t batch_group_count,
1640 const PrecisionConfig* precision_config, PaddingType padding_type,
1641 std::optional<PrimitiveType> preferred_element_type) {
1642 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1643 TF_ASSIGN_OR_RETURN(
1644 HloInstructionProto instr,
1645 DynamicConvInstruction(activations, gradients, window_strides, padding,
1646 lhs_dilation, rhs_dilation, dimension_numbers,
1647 feature_group_count, batch_group_count,
1648 precision_config, padding_type,
1649 preferred_element_type));
1650
1651 instr.set_custom_call_target("DynamicConvolutionKernelGrad");
1652     // The kernel gradient has the kernel's shape and shouldn't have any
1653     // dynamic dimensions.
1654 instr.mutable_shape()->clear_is_dynamic_dimension();
1655 return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
1656 {activations, gradients});
1657 });
1658 }
1659
1660 XlaOp XlaBuilder::DynamicConvForward(
1661 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1662 absl::Span<const std::pair<int64_t, int64_t>> padding,
1663 absl::Span<const int64_t> lhs_dilation,
1664 absl::Span<const int64_t> rhs_dilation,
1665 const ConvolutionDimensionNumbers& dimension_numbers,
1666 int64_t feature_group_count, int64_t batch_group_count,
1667 const PrecisionConfig* precision_config, PaddingType padding_type,
1668 std::optional<PrimitiveType> preferred_element_type) {
1669 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1670 TF_ASSIGN_OR_RETURN(
1671 HloInstructionProto instr,
1672 DynamicConvInstruction(
1673 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
1674 dimension_numbers, feature_group_count, batch_group_count,
1675 precision_config, padding_type, preferred_element_type));
1676 instr.set_custom_call_target("DynamicConvolutionForward");
1677
1678 return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
1679 });
1680 }
1681
1682 StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
1683 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
1684 absl::Span<const int64_t> window_strides,
1685 absl::Span<const std::pair<int64_t, int64_t>> padding,
1686 absl::Span<const int64_t> lhs_dilation,
1687 absl::Span<const int64_t> rhs_dilation,
1688 const ConvolutionDimensionNumbers& dimension_numbers,
1689 int64_t feature_group_count, int64_t batch_group_count,
1690 const PrecisionConfig* precision_config) {
1691 HloInstructionProto instr;
1692 *instr.mutable_shape() = shape.ToProto();
1693
1694 *instr.mutable_window() = window;
1695 *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
1696 instr.set_feature_group_count(feature_group_count);
1697 instr.set_batch_group_count(batch_group_count);
1698
1699 if (precision_config != nullptr) {
1700 *instr.mutable_precision_config() = *precision_config;
1701 }
1702
1703 return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
1704 }
1705
1706 XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
1707 const absl::Span<const int64_t> fft_length) {
1708 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1709 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1710 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
1711 *operand_shape, fft_type, fft_length));
1712 return FftInternal(shape, operand, fft_type, fft_length);
1713 });
1714 }
1715
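// Rough usage sketch (x is an assumed f32[8, 16] operand): a real-to-complex
// FFT of length 16 over the innermost dimension.
//
//   XlaOp spectrum = Fft(x, FftType::RFFT, /*fft_length=*/{16});
//   // result shape: c64[8, 9]  (last dimension becomes fft_length/2 + 1)
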
1716 StatusOr<XlaOp> XlaBuilder::FftInternal(
1717 const Shape& shape, XlaOp operand, const FftType fft_type,
1718 const absl::Span<const int64_t> fft_length) {
1719 HloInstructionProto instr;
1720 *instr.mutable_shape() = shape.ToProto();
1721 instr.set_fft_type(fft_type);
1722 for (int64_t i : fft_length) {
1723 instr.add_fft_length(i);
1724 }
1725
1726 return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
1727 }
1728
1729 StatusOr<XlaOp> XlaBuilder::TriangularSolveInternal(
1730 const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
1731 HloInstructionProto instr;
1732 *instr.mutable_triangular_solve_options() = std::move(options);
1733 *instr.mutable_shape() = shape.ToProto();
1734
1735 return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b});
1736 }
1737
1738 StatusOr<XlaOp> XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
1739 bool lower) {
1740 HloInstructionProto instr;
1741 CholeskyOptions& options = *instr.mutable_cholesky_options();
1742 options.set_lower(lower);
1743 *instr.mutable_shape() = shape.ToProto();
1744
1745 return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
1746 }
1747
1748 XlaOp XlaBuilder::Infeed(const Shape& shape, const std::string& config) {
1749 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1750 HloInstructionProto instr;
1751 if (!LayoutUtil::HasLayout(shape)) {
1752 return InvalidArgument("Given shape to Infeed must have a layout");
1753 }
1754 const Shape infeed_instruction_shape =
1755 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1756 *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1757 instr.set_infeed_config(config);
1758
1759 if (shape.IsArray() && sharding() &&
1760 sharding()->type() == OpSharding::OTHER) {
1761 // TODO(b/110793772): Support tiled array-shaped infeeds.
1762 return InvalidArgument(
1763 "Tiled sharding is not yet supported for array-shaped infeeds");
1764 }
1765
1766 if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1767 return InvalidArgument(
1768 "Replicated sharding is not yet supported for infeeds");
1769 }
1770
1771 // Infeed takes a single token operand. Generate the token to pass to the
1772 // infeed.
1773 XlaOp token;
1774 auto make_token = [&]() {
1775 HloInstructionProto token_instr;
1776 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1777 return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
1778 };
1779 if (sharding()) {
1780 // Arbitrarily assign token to device 0.
1781 OpSharding sharding = sharding_builder::AssignDevice(0);
1782 XlaScopedShardingAssignment scoped_sharding(this, sharding);
1783 TF_ASSIGN_OR_RETURN(token, make_token());
1784 } else {
1785 TF_ASSIGN_OR_RETURN(token, make_token());
1786 }
1787
1788 // The sharding is set by the client according to the data tuple shape.
1789 // However, the shape of the infeed instruction is a tuple containing the
1790 // data and a token. For tuple sharding type, the sharding must be changed
1791 // to accommodate the token.
1792 XlaOp infeed;
1793 if (sharding() && sharding()->type() == OpSharding::TUPLE) {
1794 // TODO(b/80000000): Remove this when clients have been updated to handle
1795 // tokens.
1796 OpSharding infeed_instruction_sharding = *sharding();
1797 // Arbitrarily assign the token to device 0.
1798 *infeed_instruction_sharding.add_tuple_shardings() =
1799 sharding_builder::AssignDevice(0);
1800 XlaScopedShardingAssignment scoped_sharding(this,
1801 infeed_instruction_sharding);
1802 TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1803 HloOpcode::kInfeed, {token}));
1804 } else {
1805 TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1806 HloOpcode::kInfeed, {token}));
1807 }
1808
1809     // The infeed instruction produces a tuple of the infeed data and a token.
1810     // Return the XLA op containing the data.
1811 // TODO(b/80000000): Remove this when clients have been updated to handle
1812 // tokens.
1813 HloInstructionProto infeed_data;
1814 *infeed_data.mutable_shape() = shape.ToProto();
1815 infeed_data.set_tuple_index(0);
1816 return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
1817 {infeed});
1818 });
1819 }
1820
1821 XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
1822 const std::string& config) {
1823 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1824 if (!LayoutUtil::HasLayout(shape)) {
1825 return InvalidArgument("Given shape to Infeed must have a layout");
1826 }
1827 const Shape infeed_instruction_shape =
1828 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1829
1830 if (shape.IsArray() && sharding() &&
1831 sharding()->type() == OpSharding::OTHER) {
1832 // TODO(b/110793772): Support tiled array-shaped infeeds.
1833 return InvalidArgument(
1834 "Tiled sharding is not yet supported for array-shaped infeeds");
1835 }
1836
1837 if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1838 return InvalidArgument(
1839 "Replicated sharding is not yet supported for infeeds");
1840 }
1841 return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
1842 });
1843 }
1844
1845 StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
1846 const Shape& infeed_instruction_shape, XlaOp token,
1847 const std::string& config) {
1848 HloInstructionProto instr;
1849 *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1850 instr.set_infeed_config(config);
1851 return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
1852 }
1853
1854 void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
1855 const std::string& outfeed_config) {
1856 ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1857 HloInstructionProto instr;
1858
1859 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1860
1861 // Check and set outfeed shape.
1862 if (!LayoutUtil::HasLayout(shape_with_layout)) {
1863 return InvalidArgument("Given shape to Outfeed must have a layout");
1864 }
1865 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1866 if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1867 return InvalidArgument(
1868 "Outfeed shape %s must be compatible with operand shape %s",
1869 ShapeUtil::HumanStringWithLayout(shape_with_layout),
1870 ShapeUtil::HumanStringWithLayout(*operand_shape));
1871 }
1872 *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1873
1874 instr.set_outfeed_config(outfeed_config);
1875
1876 // Outfeed takes a token as its second operand. Generate the token to pass
1877 // to the outfeed.
1878 HloInstructionProto token_instr;
1879 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1880 TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
1881 HloOpcode::kAfterAll, {}));
1882
1883 TF_RETURN_IF_ERROR(
1884 AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
1885 .status());
1886
1887 // The outfeed instruction produces a token. However, existing users expect
1888 // a nil shape (empty tuple). This should only be relevant if the outfeed is
1889 // the root of a computation.
1890 // TODO(b/80000000): Remove this when clients have been updated to handle
1891 // tokens.
1892 HloInstructionProto tuple_instr;
1893 *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
1894
1895 // The dummy tuple should have no sharding.
1896 {
1897 XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
1898 TF_ASSIGN_OR_RETURN(
1899 XlaOp empty_tuple,
1900 AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
1901 return empty_tuple;
1902 }
1903 });
1904 }
1905
1906 XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
1907 const Shape& shape_with_layout,
1908 const std::string& outfeed_config) {
1909 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1910 // Check and set outfeed shape.
1911 if (!LayoutUtil::HasLayout(shape_with_layout)) {
1912 return InvalidArgument("Given shape to Outfeed must have a layout");
1913 }
1914 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1915 if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1916 return InvalidArgument(
1917 "Outfeed shape %s must be compatible with operand shape %s",
1918 ShapeUtil::HumanStringWithLayout(shape_with_layout),
1919 ShapeUtil::HumanStringWithLayout(*operand_shape));
1920 }
1921 return OutfeedWithTokenInternal(operand, token, shape_with_layout,
1922 outfeed_config);
1923 });
1924 }
1925
1926 StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
1927 XlaOp operand, XlaOp token, const Shape& shape_with_layout,
1928 const std::string& outfeed_config) {
1929 HloInstructionProto instr;
1930 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1931 *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1932 instr.set_outfeed_config(outfeed_config);
1933 return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
1934 {operand, token});
1935 }
1936
1937 XlaOp XlaBuilder::CreateToken() {
1938 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1939 HloInstructionProto instr;
1940 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1941 return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
1942 });
1943 }
1944
1945 XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
1946 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1947 if (tokens.empty()) {
1948 return InvalidArgument("AfterAll requires at least one operand");
1949 }
1950 for (int i = 0, end = tokens.size(); i < end; ++i) {
1951 XlaOp operand = tokens[i];
1952 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1953 if (!operand_shape->IsToken()) {
1954 return InvalidArgument(
1955 "All operands to AfterAll must be tokens; operand %d has shape %s",
1956 i, ShapeUtil::HumanString(*operand_shape));
1957 }
1958 }
1959 HloInstructionProto instr;
1960 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1961 return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
1962 });
1963 }
1964
1965 XlaOp XlaBuilder::CustomCall(
1966 const std::string& call_target_name, absl::Span<const XlaOp> operands,
1967 const Shape& shape, const std::string& opaque,
1968 std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
1969 bool has_side_effect,
1970 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1971 output_operand_aliasing,
1972 const Literal* literal, std::optional<Window> window,
1973 std::optional<ConvolutionDimensionNumbers> dnums,
1974 CustomCallSchedule schedule, CustomCallApiVersion api_version) {
1975 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1976 if (absl::StartsWith(call_target_name, "$")) {
1977 return InvalidArgument(
1978 "Invalid custom_call_target \"%s\": Call targets that start with '$' "
1979 "are reserved for internal use.",
1980 call_target_name);
1981 }
1982 if (operand_shapes_with_layout.has_value()) {
1983 if (!LayoutUtil::HasLayout(shape)) {
1984 return InvalidArgument(
1985 "Result shape must have layout for custom call with constrained "
1986 "layout.");
1987 }
1988 if (operands.size() != operand_shapes_with_layout->size()) {
1989 return InvalidArgument(
1990 "Must specify a shape with layout for each operand for custom call "
1991 "with constrained layout; given %d shapes, expected %d",
1992 operand_shapes_with_layout->size(), operands.size());
1993 }
1994 int64_t operand_num = 0;
1995 for (const Shape& operand_shape : *operand_shapes_with_layout) {
1996 if (!LayoutUtil::HasLayout(operand_shape)) {
1997 return InvalidArgument(
1998 "No layout specified for operand %d for custom call with "
1999 "constrained layout.",
2000 operand_num);
2001 }
2002 ++operand_num;
2003 }
2004 }
2005 return CustomCallInternal(
2006 call_target_name, operands, /*computation=*/nullptr, shape, opaque,
2007 operand_shapes_with_layout, has_side_effect, output_operand_aliasing,
2008 literal, window, dnums, schedule, api_version);
2009 });
2010 }
2011
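// Rough usage sketch: the target name below is a hypothetical placeholder that
// would have to be registered with the backend separately, and the result shape
// is declared by the caller since XLA cannot infer it.
//
//   XlaOp y = CustomCall(&b, "my_custom_kernel", /*operands=*/{x},
//                        /*shape=*/ShapeUtil::MakeShape(F32, {128}));
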
2012 StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
2013 const std::string& call_target_name, absl::Span<const XlaOp> operands,
2014 const XlaComputation* computation, const Shape& shape,
2015 const std::string& opaque,
2016 std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
2017 bool has_side_effect,
2018 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2019 output_operand_aliasing,
2020 const Literal* literal, std::optional<Window> window,
2021 std::optional<ConvolutionDimensionNumbers> dnums,
2022 CustomCallSchedule schedule, CustomCallApiVersion api_version) {
2023 HloInstructionProto instr;
2024 // Bit of a hack: cudnn conv custom-calls are created through this API. Give
2025 // them a user-friendly name. (This has no effect on correctness, it's just
2026 // cosmetic.)
2027 if (call_target_name == "__cudnn$convForward") {
2028 instr.set_name("cudnn-conv");
2029 } else if (call_target_name == "__cudnn$convBackwardInput") {
2030 instr.set_name("cudnn-conv-bw-input");
2031 } else if (call_target_name == "__cudnn$convBackwardFilter") {
2032 instr.set_name("cudnn-conv-bw-filter");
2033 } else if (call_target_name == "__cudnn$convBiasActivationForward") {
2034 instr.set_name("cudnn-conv-bias-activation");
2035 }
2036 *instr.mutable_shape() = shape.ToProto();
2037 instr.set_custom_call_target(call_target_name);
2038 instr.set_backend_config(opaque);
2039 if (operand_shapes_with_layout.has_value()) {
2040 instr.set_constrain_layout(true);
2041 for (const Shape& operand_shape : *operand_shapes_with_layout) {
2042 *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
2043 }
2044 }
2045 if (literal != nullptr) {
2046 *instr.mutable_literal() = literal->ToProto();
2047 }
2048 instr.set_custom_call_has_side_effect(has_side_effect);
2049 if (computation != nullptr && !computation->IsNull()) {
2050 AddCalledComputation(*computation, &instr);
2051 }
2052 for (const auto& pair : output_operand_aliasing) {
2053 auto aliasing = instr.add_custom_call_output_operand_aliasing();
2054 aliasing->set_operand_index(pair.second.first);
2055 for (int64_t index : pair.second.second) {
2056 aliasing->add_operand_shape_index(index);
2057 }
2058 for (int64_t index : pair.first) {
2059 aliasing->add_output_shape_index(index);
2060 }
2061 }
2062 if (window.has_value()) {
2063 *instr.mutable_window() = *window;
2064 }
2065 if (dnums.has_value()) {
2066 *instr.mutable_convolution_dimension_numbers() = *dnums;
2067 }
2068 instr.set_custom_call_schedule(schedule);
2069 instr.set_custom_call_api_version(api_version);
2070 return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
2071 }
2072
2073 XlaOp XlaBuilder::CustomCall(
2074 const std::string& call_target_name, absl::Span<const XlaOp> operands,
2075 const XlaComputation& computation, const Shape& shape,
2076 const std::string& opaque,
2077 std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
2078 bool has_side_effect,
2079 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2080 output_operand_aliasing,
2081 const Literal* literal, CustomCallSchedule schedule,
2082 CustomCallApiVersion api_version) {
2083 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2084 if (absl::StartsWith(call_target_name, "$")) {
2085 return InvalidArgument(
2086 "Invalid custom_call_target \"%s\": Call targets that start with '$' "
2087 "are reserved for internal use.",
2088 call_target_name);
2089 }
2090 if (operand_shapes_with_layout.has_value()) {
2091 if (!LayoutUtil::HasLayout(shape)) {
2092 return InvalidArgument(
2093 "Result shape must have layout for custom call with constrained "
2094 "layout.");
2095 }
2096 if (operands.size() != operand_shapes_with_layout->size()) {
2097 return InvalidArgument(
2098 "Must specify a shape with layout for each operand for custom call "
2099 "with constrained layout; given %d shapes, expected %d",
2100 operand_shapes_with_layout->size(), operands.size());
2101 }
2102 int64_t operand_num = 0;
2103 for (const Shape& operand_shape : *operand_shapes_with_layout) {
2104 if (!LayoutUtil::HasLayout(operand_shape)) {
2105 return InvalidArgument(
2106 "No layout specified for operand %d for custom call with "
2107 "constrained layout.",
2108 operand_num);
2109 }
2110 ++operand_num;
2111 }
2112 }
2113 return CustomCallInternal(
2114 call_target_name, operands, &computation, shape, opaque,
2115 operand_shapes_with_layout, has_side_effect, output_operand_aliasing,
2116 literal, /*window=*/{}, /*dnums=*/{}, schedule, api_version);
2117 });
2118 }
2119
2120 XlaOp XlaBuilder::OptimizationBarrier(XlaOp operand) {
2121 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2122 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2123 Shape shape = *operand_shape;
2124 HloInstructionProto instr;
2125 *instr.mutable_shape() = shape.ToProto();
2126 return AddInstruction(std::move(instr), HloOpcode::kOptimizationBarrier,
2127 {operand});
2128 });
2129 }
2130
2131 XlaOp XlaBuilder::Transpose(XlaOp operand,
2132 absl::Span<const int64_t> permutation) {
2133 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2134 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2135 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
2136 *operand_shape, permutation));
2137 return TransposeInternal(shape, operand, permutation);
2138 });
2139 }
2140
2141 StatusOr<XlaOp> XlaBuilder::TransposeInternal(
2142 const Shape& shape, XlaOp operand, absl::Span<const int64_t> permutation) {
2143 HloInstructionProto instr;
2144 *instr.mutable_shape() = shape.ToProto();
2145 for (int64_t dim : permutation) {
2146 instr.add_dimensions(dim);
2147 }
2148 return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
2149 }
2150
2151 XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64_t> dimensions) {
2152 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2153 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2154 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
2155 *operand_shape, dimensions));
2156 return RevInternal(shape, operand, dimensions);
2157 });
2158 }
2159
2160 StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
2161 absl::Span<const int64_t> dimensions) {
2162 HloInstructionProto instr;
2163 *instr.mutable_shape() = shape.ToProto();
2164 for (int64_t dim : dimensions) {
2165 instr.add_dimensions(dim);
2166 }
2167 return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
2168 }
2169
2170 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
2171 const XlaComputation& comparator, int64_t dimension,
2172 bool is_stable) {
2173 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2174 std::vector<const Shape*> operand_shape_ptrs;
2175 TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
2176 GetOperandShapes(operands));
2177 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2178 [](const Shape& shape) { return &shape; });
2179 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
2180 HloOpcode::kSort, operand_shape_ptrs));
2181 return SortInternal(shape, operands, comparator, dimension, is_stable);
2182 });
2183 }
2184
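// Rough usage sketch (v is an assumed f32[1024] operand on builder b): an
// ascending sort along dimension 0, with a "less-than" comparator built in a
// separate sub-builder; for N sorted operands the comparator takes 2*N scalar
// parameters.
//
//   XlaBuilder cmp_b("cmp");
//   Shape s = ShapeUtil::MakeShape(F32, {});
//   Lt(Parameter(&cmp_b, 0, s, "a"), Parameter(&cmp_b, 1, s, "b"));
//   XlaComputation cmp = cmp_b.Build().value();
//   XlaOp sorted = Sort({v}, cmp, /*dimension=*/0);
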
2185 StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
2186 absl::Span<const XlaOp> operands,
2187 const XlaComputation& comparator,
2188 int64_t dimension, bool is_stable) {
2189 HloInstructionProto instr;
2190 *instr.mutable_shape() = shape.ToProto();
2191 instr.set_is_stable(is_stable);
2192 if (dimension == -1) {
2193 TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
2194 dimension = keys_shape->rank() - 1;
2195 }
2196 instr.add_dimensions(dimension);
2197 AddCalledComputation(comparator, &instr);
2198 return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
2199 }
2200
2201 XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
2202 PrimitiveType new_element_type) {
2203 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2204 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2205 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
2206 *operand_shape, new_element_type));
2207 if (primitive_util::IsComplexType(operand_shape->element_type()) &&
2208 !primitive_util::IsComplexType(new_element_type)) {
2209 operand = Real(operand);
2210 }
2211 return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
2212 });
2213 }
2214
2215 XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
2216 PrimitiveType new_element_type) {
2217 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2218 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2219 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBitcastConvertShape(
2220 *operand_shape, new_element_type));
2221 return BitcastConvertTypeInternal(shape, operand);
2222 });
2223 }
2224
2225 StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
2226 XlaOp operand) {
2227 HloInstructionProto instr;
2228 *instr.mutable_shape() = shape.ToProto();
2229 return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
2230 {operand});
2231 }
2232
2233 XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
2234 return TernaryOp(HloOpcode::kClamp, min, operand, max);
2235 }
2236
2237 XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
2238 const XlaComputation& computation,
2239 absl::Span<const int64_t> dimensions,
2240 absl::Span<const XlaOp> static_operands) {
2241 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2242 if (!static_operands.empty()) {
2243 return Unimplemented("static_operands is not supported in Map");
2244 }
2245
2246 HloInstructionProto instr;
2247 std::vector<const Shape*> operand_shape_ptrs;
2248 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
2249 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2250 [](const Shape& shape) { return &shape; });
2251 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
2252 computation.GetProgramShape());
2253 TF_ASSIGN_OR_RETURN(
2254 Shape shape, ShapeInference::InferMapShape(
2255 operand_shape_ptrs, called_program_shape, dimensions));
2256 *instr.mutable_shape() = shape.ToProto();
2257
2258 Shape output_shape(instr.shape());
2259 const int64_t output_rank = output_shape.rank();
2260 AddCalledComputation(computation, &instr);
2261 std::vector<XlaOp> new_operands(operands.begin(), operands.end());
2262 for (XlaOp& new_operand : new_operands) {
2263 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
2264 const int64_t rank = shape->rank();
2265 if (rank != output_rank) {
2266 TF_ASSIGN_OR_RETURN(new_operand,
2267 InDimBroadcast(output_shape, new_operand, {}));
2268 TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
2269 }
2270 if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
2271 TF_ASSIGN_OR_RETURN(new_operand,
2272 AddBroadcastSequence(output_shape, new_operand));
2273 }
2274 }
2275
2276 return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
2277 });
2278 }
2279
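// Rough usage sketch (x and y are assumed rank-2 f32 operands on builder b):
// applying a scalar "add" computation elementwise across both operands, mapping
// over every dimension.
//
//   XlaBuilder add_b("add");
//   Shape s = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(&add_b, 0, s, "a"), Parameter(&add_b, 1, s, "b"));
//   XlaComputation add = add_b.Build().value();
//   XlaOp z = Map(&b, {x, y}, add, /*dimensions=*/{0, 1});
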
2280 XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
2281 absl::Span<const XlaOp> parameters,
2282 const Shape& shape) {
2283 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2284 // Check the number of parameters per RNG distribution.
2285 switch (distribution) {
2286 case RandomDistribution::RNG_NORMAL:
2287 case RandomDistribution::RNG_UNIFORM:
2288 if (parameters.size() != 2) {
2289 return InvalidArgument(
2290 "RNG distribution (%s) expects 2 parameters, but got %ld",
2291 RandomDistribution_Name(distribution), parameters.size());
2292 }
2293 break;
2294 default:
2295 LOG(FATAL) << "unhandled distribution " << distribution;
2296 }
2297
2298 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
2299 return RngOpInternal(distribution, parameters, shape);
2300 });
2301 }
2302
2303 StatusOr<XlaOp> XlaBuilder::RngOpInternal(RandomDistribution distribution,
2304 absl::Span<const XlaOp> parameters,
2305 const Shape& shape) {
2306 HloInstructionProto instr;
2307 *instr.mutable_shape() = shape.ToProto();
2308 instr.set_distribution(distribution);
2309
2310 return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
2311 }
2312
2313 XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
2314 return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
2315 }
2316
2317 XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
2318 return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
2319 }
2320
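// Rough usage sketch (b is an assumed XlaBuilder): drawing f32[1024] samples
// uniformly from the half-open interval [0, 1).
//
//   XlaOp u = RngUniform(ConstantR0<float>(&b, 0.0f),
//                        ConstantR0<float>(&b, 1.0f),
//                        ShapeUtil::MakeShape(F32, {1024}));
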
2321 XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
2322 XlaOp initial_state, const Shape& shape) {
2323 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2324 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
2325 TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
2326 Shape output_shape = shape;
2327 switch (output_shape.element_type()) {
2328 case PrimitiveType::S8:
2329 case PrimitiveType::U8:
2330 output_shape.set_element_type(PrimitiveType::U8);
2331 break;
2332 case PrimitiveType::BF16:
2333 case PrimitiveType::F16:
2334 case PrimitiveType::S16:
2335 case PrimitiveType::U16:
2336 output_shape.set_element_type(PrimitiveType::U16);
2337 break;
2338 case PrimitiveType::F32:
2339 case PrimitiveType::S32:
2340 case PrimitiveType::U32:
2341 output_shape.set_element_type(PrimitiveType::U32);
2342 break;
2343 case PrimitiveType::F64:
2344 case PrimitiveType::S64:
2345 case PrimitiveType::U64:
2346 output_shape.set_element_type(PrimitiveType::U64);
2347 break;
2348 default:
2349 return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
2350 PrimitiveType_Name(output_shape.element_type()));
2351 }
2352 return RngBitGeneratorInternal(
2353 ShapeUtil::MakeTupleShapeWithPtrs({&state_shape, &output_shape}),
2354 algorithm, initial_state);
2355 });
2356 }
2357
2358 StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
2359 const Shape& full_result_shape, RandomAlgorithm algorithm,
2360 XlaOp initial_state) {
2361 HloInstructionProto instr;
2362 *instr.mutable_shape() = full_result_shape.ToProto();
2363 instr.set_rng_algorithm(algorithm);
2364 return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
2365 {initial_state});
2366 }
2367
2368 XlaOp XlaBuilder::While(const XlaComputation& condition,
2369 const XlaComputation& body, XlaOp init) {
2370 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2371 // Infer shape.
2372 TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
2373 TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
2374 condition.GetProgramShape());
2375 TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
2376 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
2377 condition_program_shape,
2378 body_program_shape, *init_shape));
2379 return WhileInternal(shape, condition, body, init);
2380 });
2381 }
2382
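// Rough usage sketch (names are illustrative assumptions): an S32 counter loop
// equivalent to "while (i < 10) i += 1". Both sub-computations take the loop
// state, here a scalar, as their single parameter.
//
//   Shape s32 = ShapeUtil::MakeShape(S32, {});
//   XlaBuilder cond_b("cond");
//   Lt(Parameter(&cond_b, 0, s32, "i"), ConstantR0<int32_t>(&cond_b, 10));
//   XlaBuilder body_b("body");
//   Add(Parameter(&body_b, 0, s32, "i"), ConstantR0<int32_t>(&body_b, 1));
//   XlaOp result = While(cond_b.Build().value(), body_b.Build().value(),
//                        ConstantR0<int32_t>(&b, 0));
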
2383 StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
2384 const XlaComputation& condition,
2385 const XlaComputation& body,
2386 XlaOp init) {
2387 HloInstructionProto instr;
2388 *instr.mutable_shape() = shape.ToProto();
2389 // Body comes before condition computation in the vector.
2390 AddCalledComputation(body, &instr);
2391 AddCalledComputation(condition, &instr);
2392 return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
2393 }
2394
2395 XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
2396 const GatherDimensionNumbers& dimension_numbers,
2397 absl::Span<const int64_t> slice_sizes,
2398 bool indices_are_sorted) {
2399 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2400 TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
2401 TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
2402 GetShapePtr(start_indices));
2403 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
2404 *input_shape, *start_indices_shape,
2405 dimension_numbers, slice_sizes));
2406 return GatherInternal(shape, input, start_indices, dimension_numbers,
2407 slice_sizes, indices_are_sorted);
2408 });
2409 }
2410
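// Rough usage sketch (operand is an assumed f32[N, D] array, indices an s32[K]
// vector of row numbers; N, D and K are placeholders): gathering whole rows
// produces an f32[K, D] result.
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);
//   dnums.add_collapsed_slice_dims(0);
//   dnums.add_start_index_map(0);
//   dnums.set_index_vector_dim(1);
//   XlaOp rows = Gather(operand, indices, dnums, /*slice_sizes=*/{1, D});
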
2411 StatusOr<XlaOp> XlaBuilder::GatherInternal(
2412 const Shape& shape, XlaOp input, XlaOp start_indices,
2413 const GatherDimensionNumbers& dimension_numbers,
2414 absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
2415 HloInstructionProto instr;
2416 instr.set_indices_are_sorted(indices_are_sorted);
2417 *instr.mutable_shape() = shape.ToProto();
2418 *instr.mutable_gather_dimension_numbers() = dimension_numbers;
2419 for (int64_t bound : slice_sizes) {
2420 instr.add_gather_slice_sizes(bound);
2421 }
2422
2423 return AddInstruction(std::move(instr), HloOpcode::kGather,
2424 {input, start_indices});
2425 }
2426
2427 XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
2428 const XlaComputation& update_computation,
2429 const ScatterDimensionNumbers& dimension_numbers,
2430 bool indices_are_sorted, bool unique_indices) {
2431 return Scatter(absl::MakeConstSpan(&input, 1), scatter_indices,
2432 absl::MakeConstSpan(&updates, 1), update_computation,
2433 dimension_numbers, indices_are_sorted, unique_indices);
2434 }
2435
2436 XlaOp XlaBuilder::Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
2437 absl::Span<const XlaOp> updates,
2438 const XlaComputation& update_computation,
2439 const ScatterDimensionNumbers& dimension_numbers,
2440 bool indices_are_sorted, bool unique_indices) {
2441 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2442 if (inputs.empty()) {
2443 return InvalidArgument("Scatter inputs cannot be empty.");
2444 }
2445 if (inputs.size() != updates.size()) {
2446 return InvalidArgument(
2447 "Scatter should have same number of inputs and updates: %d vs %d.",
2448 inputs.size(), updates.size());
2449 }
2450 absl::InlinedVector<const Shape*, 3> operand_shapes;
2451 operand_shapes.reserve(inputs.size() + 1 + updates.size());
2452 for (const XlaOp& input : inputs) {
2453 TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
2454 operand_shapes.push_back(input_shape);
2455 }
2456 TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
2457 GetShapePtr(scatter_indices));
2458 operand_shapes.push_back(scatter_indices_shape);
2459 for (const XlaOp& update : updates) {
2460 TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
2461 operand_shapes.push_back(update_shape);
2462 }
2463 TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
2464 update_computation.GetProgramShape());
2465 TF_ASSIGN_OR_RETURN(Shape shape,
2466 ShapeInference::InferScatterShape(
2467 operand_shapes, to_apply_shape, dimension_numbers));
2468 return ScatterInternal(shape, inputs, scatter_indices, updates,
2469 update_computation, dimension_numbers,
2470 indices_are_sorted, unique_indices);
2471 });
2472 }
2473
2474 StatusOr<XlaOp> XlaBuilder::ScatterInternal(
2475 const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
2476 absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
2477 const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
2478 bool unique_indices) {
2479 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2480 HloInstructionProto instr;
2481 instr.set_indices_are_sorted(indices_are_sorted);
2482 instr.set_unique_indices(unique_indices);
2483 *instr.mutable_shape() = shape.ToProto();
2484 *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
2485
2486 AddCalledComputation(update_computation, &instr);
2487 absl::InlinedVector<XlaOp, 3> operands;
2488 operands.reserve(inputs.size() + 1 + updates.size());
2489 absl::c_copy(inputs, std::back_inserter(operands));
2490 operands.push_back(scatter_indices);
2491 absl::c_copy(updates, std::back_inserter(operands));
2492 return AddInstruction(std::move(instr), HloOpcode::kScatter, operands);
2493 });
2494 }
2495
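// Predicated Conditional: validates that the predicate is a scalar of PRED
// type before lowering to the indexed (branch) form below.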
2496 XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
2497 const XlaComputation& true_computation,
2498 XlaOp false_operand,
2499 const XlaComputation& false_computation) {
2500 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2501 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(predicate));
2502
2503 if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
2504 return InvalidArgument(
2505 "Argument to predicated-Conditional is not a scalar of PRED type "
2506 "(%s).",
2507 ShapeUtil::HumanString(*shape));
2508 }
2509     // The index of true_computation must be 0 and that of false_computation
2510     // must be 1.
2511 return ConditionalImpl(predicate, {&true_computation, &false_computation},
2512 {true_operand, false_operand});
2513 });
2514 }
2515
2516 XlaOp XlaBuilder::Conditional(
2517 XlaOp branch_index,
2518 absl::Span<const XlaComputation* const> branch_computations,
2519 absl::Span<const XlaOp> branch_operands) {
2520 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2521 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(branch_index));
2522
2523 if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
2524 return InvalidArgument(
2525 "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
2526 ShapeUtil::HumanString(*shape));
2527 }
2528 return ConditionalImpl(branch_index, branch_computations, branch_operands);
2529 });
2530 }
2531
2532 XlaOp XlaBuilder::ConditionalImpl(
2533 XlaOp branch_index,
2534 absl::Span<const XlaComputation* const> branch_computations,
2535 absl::Span<const XlaOp> branch_operands) {
2536 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2537 HloInstructionProto instr;
2538
2539 TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
2540 GetShapePtr(branch_index));
2541 std::vector<Shape> branch_operand_shapes(branch_operands.size());
2542 std::vector<ProgramShape> branch_computation_shapes(
2543 branch_computations.size());
2544 for (int j = 0, end = branch_operands.size(); j < end; ++j) {
2545 TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
2546 GetShape(branch_operands[j]));
2547 TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
2548 branch_computations[j]->GetProgramShape());
2549 }
2550 TF_ASSIGN_OR_RETURN(const Shape shape,
2551 ShapeInference::InferConditionalShape(
2552 *branch_index_shape, branch_computation_shapes,
2553 branch_operand_shapes));
2554 *instr.mutable_shape() = shape.ToProto();
2555
2556 for (const XlaComputation* branch_computation : branch_computations) {
2557 AddCalledComputation(*branch_computation, &instr);
2558 }
2559
2560 std::vector<XlaOp> operands(1, branch_index);
2561 for (const XlaOp branch_operand : branch_operands) {
2562 operands.emplace_back(branch_operand);
2563 }
2564 return AddInstruction(std::move(instr), HloOpcode::kConditional,
2565 absl::MakeSpan(operands));
2566 });
2567 }
2568
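// Verifies that the given op was created by this builder; using an XlaOp
// across builders is an error.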
2569 Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
2570 if (this != op.builder()) {
2571 return InvalidArgument(
2572 "XlaOp with handle %d is built by builder '%s', but is trying to use "
2573 "it in builder '%s'",
2574 op.handle(), op.builder()->name(), name());
2575 }
2576 return OkStatus();
2577 }
2578
2579 XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
2580 const XlaComputation& computation,
2581 absl::Span<const int64_t> dimensions_to_reduce) {
2582 return Reduce(absl::Span<const XlaOp>({operand}),
2583 absl::Span<const XlaOp>({init_value}), computation,
2584 dimensions_to_reduce);
2585 }
2586
2587 XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
2588 absl::Span<const XlaOp> init_values,
2589 const XlaComputation& computation,
2590 absl::Span<const int64_t> dimensions_to_reduce) {
2591 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2592 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
2593 computation.GetProgramShape());
2594
2595 std::vector<XlaOp> all_operands;
2596 all_operands.insert(all_operands.end(), operands.begin(), operands.end());
2597 all_operands.insert(all_operands.end(), init_values.begin(),
2598 init_values.end());
2599
2600 std::vector<const Shape*> operand_shape_ptrs;
2601 TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
2602 GetOperandShapes(all_operands));
2603 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2604 [](const Shape& shape) { return &shape; });
2605
2606 TF_ASSIGN_OR_RETURN(
2607 Shape shape,
2608 ShapeInference::InferReduceShape(
2609 operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
2610 return ReduceInternal(shape, all_operands, computation,
2611 dimensions_to_reduce);
2612 });
2613 }
2614
2615 StatusOr<XlaOp> XlaBuilder::ReduceInternal(
2616 const Shape& shape, absl::Span<const XlaOp> all_operands,
2617 const XlaComputation& computation,
2618 absl::Span<const int64_t> dimensions_to_reduce) {
2619 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2620 HloInstructionProto instr;
2621 *instr.mutable_shape() = shape.ToProto();
2622
2623 for (int64_t dim : dimensions_to_reduce) {
2624 instr.add_dimensions(dim);
2625 }
2626
2627 AddCalledComputation(computation, &instr);
2628 return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
2629 });
2630 }
2631
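// Reduces the operand across all of its dimensions by passing every dimension
// number in [0, rank) to Reduce.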
2632 XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
2633 const XlaComputation& computation) {
2634 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2635 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2636 std::vector<int64_t> all_dimnos(operand_shape->rank());
2637 std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
2638 return Reduce(operand, init_value, computation, all_dimnos);
2639 });
2640 }
2641
2642 XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
2643 const XlaComputation& computation,
2644 absl::Span<const int64_t> window_dimensions,
2645 absl::Span<const int64_t> window_strides,
2646 Padding padding) {
2647 return ReduceWindow(absl::MakeSpan(&operand, 1),
2648 absl::MakeSpan(&init_value, 1), computation,
2649 window_dimensions, window_strides, padding);
2650 }
2651
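// Variadic ReduceWindow: validates the window against each operand shape,
// computes padding from the Padding enum, and falls back to a
// "DynamicReduceWindowSamePadding" custom call when SAME padding meets a
// dynamic, non-trivial window dimension.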
2652 XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
2653 absl::Span<const XlaOp> init_values,
2654 const XlaComputation& computation,
2655 absl::Span<const int64_t> window_dimensions,
2656 absl::Span<const int64_t> window_strides,
2657 Padding padding) {
2658 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2659 const Shape* operand_shape = nullptr;
2660 for (const auto& operand : operands) {
2661 TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
2662 TF_RETURN_IF_ERROR(ValidatePaddingValues(
2663 operand_shape->dimensions(), window_dimensions, window_strides));
2664 }
2665 CHECK(operand_shape != nullptr);
2666 std::vector<std::pair<int64_t, int64_t>> padding_values =
2667 MakePadding(operand_shape->dimensions(), window_dimensions,
2668 window_strides, padding);
2669 TF_ASSIGN_OR_RETURN(auto window,
2670 ShapeInference::InferWindowFromDimensions(
2671 window_dimensions, window_strides, padding_values,
2672 /*lhs_dilation=*/{},
2673 /*rhs_dilation=*/{}));
2674 PaddingType padding_type = PADDING_INVALID;
2675 for (int64_t i = 0; i < operand_shape->rank(); ++i) {
2676 if (operand_shape->is_dynamic_dimension(i) &&
2677 !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
2678 padding == Padding::kSame) {
2679         // SAME padding can create dynamic padding sizes. The padding sizes
2680         // need to be rewritten by the dynamic padder using HloInstructions.
2681         // We create a CustomCall to handle this.
2682 padding_type = PADDING_SAME;
2683 }
2684 }
2685 if (padding_type == PADDING_SAME) {
2686 TF_ASSIGN_OR_RETURN(
2687 HloInstructionProto instr,
2688 ReduceWindowInternal(operands, init_values, computation,
2689 window_dimensions, window_strides, {}, {},
2690 padding_values));
2691 instr.set_custom_call_target("DynamicReduceWindowSamePadding");
2692 std::vector<XlaOp> args;
2693 args.insert(args.end(), operands.begin(), operands.end());
2694 args.insert(args.end(), init_values.begin(), init_values.end());
2695 return AddInstruction(std::move(instr), HloOpcode::kCustomCall, args);
2696 }
2697 return ReduceWindowWithGeneralPadding(
2698 operands, init_values, computation, window_dimensions, window_strides,
2699 /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
2700 });
2701 }
2702
2703 XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
2704 absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
2705 const XlaComputation& computation,
2706 absl::Span<const int64_t> window_dimensions,
2707 absl::Span<const int64_t> window_strides,
2708 absl::Span<const int64_t> base_dilations,
2709 absl::Span<const int64_t> window_dilations,
2710 absl::Span<const std::pair<int64_t, int64_t>> padding) {
2711 std::vector<const Shape*> operand_shapes, init_shapes;
2712 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2713 if (operands.size() == 1) {
2714 const auto& operand = operands[0];
2715 const auto& init_value = init_values[0];
2716 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2717 operand_shapes.push_back(operand_shape);
2718 TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
2719 init_shapes.push_back(init_shape);
2720
2721 TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
2722 computation.GetProgramShape());
2723 TF_ASSIGN_OR_RETURN(auto window,
2724 ShapeInference::InferWindowFromDimensions(
2725 window_dimensions, window_strides, padding,
2726 /*lhs_dilation=*/base_dilations,
2727 /*rhs_dilation=*/window_dilations));
2728 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
2729 absl::MakeSpan(operand_shapes),
2730 absl::MakeSpan(init_shapes), window,
2731 to_apply_shape));
2732 return ReduceWindowInternal(shape, operands[0], init_values[0],
2733 computation, window);
2734 }
2735
2736 TF_ASSIGN_OR_RETURN(
2737 HloInstructionProto instr,
2738 ReduceWindowInternal(operands, init_values, computation,
2739 window_dimensions, window_strides, base_dilations,
2740 window_dilations, padding));
2741 std::vector<XlaOp> args;
2742 args.insert(args.end(), operands.begin(), operands.end());
2743 args.insert(args.end(), init_values.begin(), init_values.end());
2744 return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, args);
2745 });
2746 }
2747
2748 StatusOr<HloInstructionProto> XlaBuilder::ReduceWindowInternal(
2749 absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
2750 const XlaComputation& computation,
2751 absl::Span<const int64_t> window_dimensions,
2752 absl::Span<const int64_t> window_strides,
2753 absl::Span<const int64_t> base_dilations,
2754 absl::Span<const int64_t> window_dilations,
2755 absl::Span<const std::pair<int64_t, int64_t>> padding) {
2756 std::vector<const Shape*> operand_shapes, init_shapes;
2757 for (int i = 0; i < operands.size(); ++i) {
2758 const auto& operand = operands[i];
2759 const auto& init_value = init_values[i];
2760 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2761 operand_shapes.push_back(operand_shape);
2762 TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
2763 init_shapes.push_back(init_shape);
2764 }
2765 TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
2766 computation.GetProgramShape());
2767 TF_ASSIGN_OR_RETURN(auto window,
2768 ShapeInference::InferWindowFromDimensions(
2769 window_dimensions, window_strides, padding,
2770 /*lhs_dilation=*/base_dilations,
2771 /*rhs_dilation=*/window_dilations));
2772 TF_ASSIGN_OR_RETURN(Shape shape,
2773 ShapeInference::InferReduceWindowShape(
2774 absl::MakeSpan(operand_shapes),
2775 absl::MakeSpan(init_shapes), window, to_apply_shape));
2776 HloInstructionProto instr;
2777 *instr.mutable_shape() = shape.ToProto();
2778 *instr.mutable_window() = std::move(window);
2779 AddCalledComputation(computation, &instr);
2780 return instr;
2781 }
2782
2783 StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
2784 const Shape& shape, XlaOp operand, XlaOp init_value,
2785 const XlaComputation& computation, Window window) {
2786 HloInstructionProto instr;
2787 *instr.mutable_shape() = shape.ToProto();
2788 *instr.mutable_window() = std::move(window);
2789
2790 AddCalledComputation(computation, &instr);
2791 return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
2792 {operand, init_value});
2793 }
2794
2795 XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
2796 float epsilon, int64_t feature_index) {
2797 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2798 HloInstructionProto instr;
2799
2800 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2801 TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2802 TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2803 TF_ASSIGN_OR_RETURN(
2804 Shape shape,
2805 ShapeInference::InferBatchNormTrainingShape(
2806 *operand_shape, *scale_shape, *offset_shape, feature_index));
2807 *instr.mutable_shape() = shape.ToProto();
2808
2809 instr.set_epsilon(epsilon);
2810 instr.set_feature_index(feature_index);
2811
2812 return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
2813 {operand, scale, offset});
2814 });
2815 }
2816
2817 XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
2818 XlaOp mean, XlaOp variance, float epsilon,
2819 int64_t feature_index) {
2820 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2821 HloInstructionProto instr;
2822
2823 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2824 TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2825 TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2826 TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
2827 TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
2828 TF_ASSIGN_OR_RETURN(Shape shape,
2829 ShapeInference::InferBatchNormInferenceShape(
2830 *operand_shape, *scale_shape, *offset_shape,
2831 *mean_shape, *variance_shape, feature_index));
2832 *instr.mutable_shape() = shape.ToProto();
2833
2834 instr.set_epsilon(epsilon);
2835 instr.set_feature_index(feature_index);
2836
2837 return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
2838 {operand, scale, offset, mean, variance});
2839 });
2840 }
2841
2842 XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
2843 XlaOp batch_var, XlaOp grad_output,
2844 float epsilon, int64_t feature_index) {
2845 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2846 HloInstructionProto instr;
2847
2848 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2849 TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2850 TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape, GetShapePtr(batch_mean));
2851 TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
2852 TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
2853 GetShapePtr(grad_output));
2854 TF_ASSIGN_OR_RETURN(
2855 Shape shape, ShapeInference::InferBatchNormGradShape(
2856 *operand_shape, *scale_shape, *batch_mean_shape,
2857 *batch_var_shape, *grad_output_shape, feature_index));
2858 *instr.mutable_shape() = shape.ToProto();
2859
2860 instr.set_epsilon(epsilon);
2861 instr.set_feature_index(feature_index);
2862
2863 return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
2864 {operand, scale, batch_mean, batch_var, grad_output});
2865 });
2866 }
2867
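// Emits a kAllGather instruction that concatenates shard_count per-replica
// operands along all_gather_dimension; an explicit layout, if provided,
// constrains the result layout.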
2868 XlaOp XlaBuilder::AllGather(XlaOp operand, int64_t all_gather_dimension,
2869 int64_t shard_count,
2870 absl::Span<const ReplicaGroup> replica_groups,
2871 const std::optional<ChannelHandle>& channel_id,
2872 const std::optional<Layout>& layout,
2873 const std::optional<bool> use_global_device_ids) {
2874 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2875 HloInstructionProto instr;
2876 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2877
2878 TF_ASSIGN_OR_RETURN(
2879 Shape inferred_shape,
2880 ShapeInference::InferAllGatherShape({operand_shape},
2881 all_gather_dimension, shard_count));
2882 if (layout) {
2883 *inferred_shape.mutable_layout() = *layout;
2884 instr.set_constrain_layout(true);
2885 }
2886 *instr.mutable_shape() = inferred_shape.ToProto();
2887
2888 instr.add_dimensions(all_gather_dimension);
2889 for (const ReplicaGroup& group : replica_groups) {
2890 *instr.add_replica_groups() = group;
2891 }
2892 if (channel_id.has_value()) {
2893 instr.set_channel_id(channel_id->handle());
2894 }
2895 if (use_global_device_ids.has_value()) {
2896 instr.set_use_global_device_ids(use_global_device_ids.value());
2897 }
2898
2899 TF_ASSIGN_OR_RETURN(
2900 auto all_gather,
2901 AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand}));
2902 return all_gather;
2903 });
2904 }
2905
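// Lowers CrossReplicaSum to an AllReduce whose reduction computation adds the
// two scalar operands (or ORs them when the element type is PRED).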
2906 XlaOp XlaBuilder::CrossReplicaSum(
2907 XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
2908 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2909 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
2910 const Shape* element_shape;
2911 if (shape->IsTuple()) {
2912 if (shape->tuple_shapes_size() == 0) {
2913 return Unimplemented(
2914 "0 element tuple CrossReplicaSum is not supported");
2915 }
2916 element_shape = &shape->tuple_shapes(0);
2917 } else {
2918 element_shape = shape;
2919 }
2920 const Shape scalar_shape =
2921 ShapeUtil::MakeShape(element_shape->element_type(), {});
2922 auto b = CreateSubBuilder("sum");
2923 auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
2924 auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
2925 if (scalar_shape.element_type() == PRED) {
2926 Or(x, y);
2927 } else {
2928 Add(x, y);
2929 }
2930 TF_ASSIGN_OR_RETURN(auto computation, b->Build());
2931 return AllReduce(operand, computation, replica_groups,
2932 /*channel_id=*/std::nullopt);
2933 });
2934 }
2935
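// Emits a kAllReduce instruction. Tuple operands are flattened into
// per-element operands, and a single-element tuple result is re-wrapped into a
// tuple so the caller-visible shape matches the operand.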
2936 XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
2937 absl::Span<const ReplicaGroup> replica_groups,
2938 const std::optional<ChannelHandle>& channel_id,
2939 const std::optional<Shape>& shape_with_layout,
2940 const std::optional<bool> use_global_device_ids) {
2941 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2942 HloInstructionProto instr;
2943 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2944 std::vector<const Shape*> operand_shapes;
2945 std::vector<XlaOp> operands;
2946 if (operand_shape->IsTuple()) {
2947 if (operand_shape->tuple_shapes_size() == 0) {
2948 return Unimplemented("0 element tuple AllReduce is not supported");
2949 }
2950 for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
2951 if (operand_shape->tuple_shapes(i).element_type() !=
2952 operand_shape->tuple_shapes(0).element_type()) {
2953 return Unimplemented(
2954 "All the shapes of a tuple input of AllReduce must have the same "
2955 "element type");
2956 }
2957 operand_shapes.push_back(&operand_shape->tuple_shapes(i));
2958 operands.push_back(GetTupleElement(operand, i));
2959 }
2960 } else {
2961 operand_shapes.push_back(operand_shape);
2962 operands.push_back(operand);
2963 }
2964
2965 TF_ASSIGN_OR_RETURN(Shape inferred_shape,
2966 ShapeInference::InferAllReduceShape(operand_shapes));
2967 if (shape_with_layout) {
2968 if (!LayoutUtil::HasLayout(*shape_with_layout)) {
2969 return InvalidArgument("shape_with_layout must have the layout set: %s",
2970 shape_with_layout->ToString());
2971 }
2972 if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
2973 return InvalidArgument(
2974 "Provided shape_with_layout must be compatible with the "
2975 "operand shape: %s vs %s",
2976 shape_with_layout->ToString(), operand_shape->ToString());
2977 }
2978 instr.set_constrain_layout(true);
2979 if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
2980 // For a single-element tuple, take the tuple element shape.
2981 TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
2982 *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
2983 } else {
2984 *instr.mutable_shape() = shape_with_layout->ToProto();
2985 }
2986 } else {
2987 *instr.mutable_shape() = inferred_shape.ToProto();
2988 }
2989
2990 for (const ReplicaGroup& group : replica_groups) {
2991 *instr.add_replica_groups() = group;
2992 }
2993
2994 if (channel_id.has_value()) {
2995 instr.set_channel_id(channel_id->handle());
2996 }
2997
2998 if (use_global_device_ids.has_value()) {
2999 instr.set_use_global_device_ids(*use_global_device_ids);
3000 }
3001
3002 AddCalledComputation(computation, &instr);
3003
3004 TF_ASSIGN_OR_RETURN(
3005 auto all_reduce,
3006 AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
3007 if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
3008 // For a single-element tuple, wrap the result into a tuple.
3009 TF_RET_CHECK(operand_shapes.size() == 1);
3010 TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
3011 return Tuple({all_reduce});
3012 }
3013 return all_reduce;
3014 });
3015 }
3016
3017 XlaOp XlaBuilder::ReduceScatter(
3018 XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
3019 int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups,
3020 const std::optional<ChannelHandle>& channel_id,
3021 const std::optional<Layout>& layout,
3022 const std::optional<bool> use_global_device_ids) {
3023 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3024 HloInstructionProto instr;
3025 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3026 std::vector<const Shape*> operand_shapes;
3027 std::vector<XlaOp> operands;
3028 if (operand_shape->IsTuple()) {
3029 if (operand_shape->tuple_shapes_size() == 0) {
3030 return Unimplemented("0 element tuple ReduceScatter is not supported");
3031 }
3032 for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
3033 if (operand_shape->tuple_shapes(i).element_type() !=
3034 operand_shape->tuple_shapes(0).element_type()) {
3035 return Unimplemented(
3036 "All the shapes of a tuple input of ReduceScatter must have "
3037 "the same "
3038 "element type");
3039 }
3040 operand_shapes.push_back(&operand_shape->tuple_shapes(i));
3041 operands.push_back(GetTupleElement(operand, i));
3042 }
3043 } else {
3044 operand_shapes.push_back(operand_shape);
3045 operands.push_back(operand);
3046 }
3047
3048 TF_ASSIGN_OR_RETURN(Shape inferred_shape,
3049 ShapeInference::InferReduceScatterShape(
3050 operand_shapes, scatter_dimension, shard_count));
3051 if (layout) {
3052 *inferred_shape.mutable_layout() = *layout;
3053 instr.set_constrain_layout(true);
3054 }
3055 *instr.mutable_shape() = inferred_shape.ToProto();
3056
3057 AddCalledComputation(computation, &instr);
3058
3059 instr.add_dimensions(scatter_dimension);
3060 for (const ReplicaGroup& group : replica_groups) {
3061 *instr.add_replica_groups() = group;
3062 }
3063 if (channel_id.has_value()) {
3064 instr.set_channel_id(channel_id->handle());
3065 }
3066 if (use_global_device_ids.has_value()) {
3067 instr.set_use_global_device_ids(use_global_device_ids.value());
3068 }
3069
3070 TF_ASSIGN_OR_RETURN(
3071 auto reduce_scatter,
3072 AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
3073 return reduce_scatter;
3074 });
3075 }
3076
3077 XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
3078 int64_t concat_dimension, int64_t split_count,
3079 absl::Span<const ReplicaGroup> replica_groups,
3080 const std::optional<Layout>& layout) {
3081   // An array AllToAll may need to violate the layout constraint to be legal,
3082   // so use the tuple version instead.
3083 if (layout.has_value()) {
3084 return AllToAllTuple(operand, split_dimension, concat_dimension,
3085 split_count, replica_groups, layout);
3086 }
3087 return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
3088 replica_groups);
3089 }
3090
3091 XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
3092 int64_t concat_dimension, int64_t split_count,
3093 absl::Span<const ReplicaGroup> replica_groups) {
3094 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3095 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3096 TF_ASSIGN_OR_RETURN(
3097 const Shape all_to_all_shape,
3098 ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
3099 concat_dimension, split_count));
3100 HloInstructionProto instr;
3101 *instr.mutable_shape() = operand_shape->ToProto();
3102 if (replica_groups.empty()) {
3103 auto* group = instr.add_replica_groups();
3104 for (int64_t i = 0; i < split_count; ++i) {
3105 group->add_replica_ids(i);
3106 }
3107 } else {
3108 for (const ReplicaGroup& group : replica_groups) {
3109 *instr.add_replica_groups() = group;
3110 }
3111 }
3112 instr.add_dimensions(split_dimension);
3113 TF_ASSIGN_OR_RETURN(
3114 XlaOp all_to_all,
3115 AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
3116 if (split_dimension == concat_dimension) {
3117 return all_to_all;
3118 }
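    // When the split and concat dimensions differ, the kAllToAll result still
    // has the operand's shape, so the concatenation is materialized with a
    // reshape/transpose/reshape sequence below. Illustrative example (not from
    // the original source): for an operand of shape [4, 6] with
    // split_dimension=0, concat_dimension=1 and split_count=2, the result is
    // reshaped to [2, 2, 6], transposed with permutation {1, 0, 2} so that the
    // split_count dimension sits just before the concat dimension, and finally
    // reshaped to the inferred shape [2, 12].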
3119 DimensionVector sizes;
3120 for (int64_t i = 0; i < operand_shape->rank(); ++i) {
3121 if (i != split_dimension) {
3122 sizes.push_back(operand_shape->dimensions(i));
3123 continue;
3124 }
3125 sizes.push_back(split_count);
3126 sizes.push_back(operand_shape->dimensions(i) / split_count);
3127 }
3128 all_to_all = Reshape(all_to_all, sizes);
3129
3130 std::vector<int64_t> permutation;
3131 const auto rank = operand_shape->rank();
3132 permutation.reserve(rank + 1);
3133 for (int64_t i = 0; i < rank; ++i) {
3134 int64_t dim_after_reshape = i >= split_dimension ? i + 1 : i;
3135 if (i == concat_dimension) {
3136 permutation.push_back(split_dimension);
3137 }
3138 permutation.push_back(dim_after_reshape);
3139 }
3140 all_to_all = Transpose(all_to_all, permutation);
3141 return Reshape(all_to_all_shape, all_to_all);
3142 });
3143 }
3144
3145 XlaOp XlaBuilder::AllToAllTuple(absl::Span<const XlaOp> operands,
3146 absl::Span<const ReplicaGroup> replica_groups,
3147 const std::optional<Layout>& layout) {
3148 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3149 HloInstructionProto instr;
3150 TF_ASSIGN_OR_RETURN(auto operand_shapes, this->GetOperandShapes(operands));
3151 std::vector<const Shape*> operand_shape_ptrs;
3152 operand_shape_ptrs.reserve(operand_shapes.size());
3153 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
3154 [](const Shape& shape) { return &shape; });
3155 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferAllToAllTupleShape(
3156 operand_shape_ptrs));
3157
3158 if (layout) {
3159 TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
3160 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
3161 const int64_t layout_minor_to_major_size =
3162 layout->minor_to_major().size();
3163 if (layout_minor_to_major_size != shape.tuple_shapes(i).rank()) {
3164 return InvalidArgument(
3165 "Provided layout must be compatible with the operands' shape. "
3166 "The layout is %s, but operand %d has shape %s.",
3167 layout->ToString(), i, shape.tuple_shapes(i).ToString());
3168 }
3169 *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
3170 }
3171 instr.set_constrain_layout(true);
3172 }
3173 *instr.mutable_shape() = shape.ToProto();
3174
3175 for (const ReplicaGroup& group : replica_groups) {
3176 *instr.add_replica_groups() = group;
3177 }
3178
3179 return AddInstruction(std::move(instr), HloOpcode::kAllToAll, operands);
3180 });
3181 }
3182
3183 XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64_t split_dimension,
3184 int64_t concat_dimension, int64_t split_count,
3185 absl::Span<const ReplicaGroup> replica_groups,
3186 const std::optional<Layout>& layout) {
3187 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3188 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3189
3190     // The HloInstruction for AllToAll currently only handles the data
3191     // communication: it accepts N already-split parts and scatters them to N
3192     // cores, and each core gathers the N received parts into a tuple as the
3193     // output. So here we explicitly split the operand before the HLO AllToAll
3194     // and concatenate the tuple elements afterwards.
3195 //
3196 // First, run shape inference to make sure the shapes are valid.
3197 TF_RETURN_IF_ERROR(
3198 ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
3199 concat_dimension, split_count)
3200 .status());
3201
3202 // Split into N parts.
3203 std::vector<XlaOp> slices;
3204 slices.reserve(split_count);
3205 const int64_t block_size =
3206 operand_shape->dimensions(split_dimension) / split_count;
3207 for (int i = 0; i < split_count; i++) {
3208 slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
3209 /*limit_index=*/(i + 1) * block_size,
3210 /*stride=*/1, /*dimno=*/split_dimension));
3211 }
3212
3213 // Handle data communication.
3214 XlaOp alltoall = this->AllToAllTuple(slices, replica_groups, layout);
3215
3216 // Concat the N received parts.
3217 std::vector<XlaOp> received;
3218 received.reserve(split_count);
3219 for (int i = 0; i < split_count; i++) {
3220 received.push_back(this->GetTupleElement(alltoall, i));
3221 }
3222 return this->ConcatInDim(received, concat_dimension);
3223 });
3224 }
3225
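// Emits a kCollectivePermute instruction that sends the operand from each
// source replica to the paired target replica in source_target_pairs.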
3226 XlaOp XlaBuilder::CollectivePermute(
3227 XlaOp operand,
3228 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
3229 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3230 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3231 HloInstructionProto instr;
3232 TF_ASSIGN_OR_RETURN(
3233 Shape shape,
3234 ShapeInference::InferCollectivePermuteShape({operand_shape}));
3235 *instr.mutable_shape() = shape.ToProto();
3236
3237 for (const auto& pair : source_target_pairs) {
3238 auto* proto_pair = instr.add_source_target_pairs();
3239 proto_pair->set_source(pair.first);
3240 proto_pair->set_target(pair.second);
3241 }
3242
3243 return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
3244 {operand});
3245 });
3246 }
3247
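// Emits a kReplicaId instruction, which produces a scalar U32 holding the
// executing replica's id.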
3248 XlaOp XlaBuilder::ReplicaId() {
3249 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3250 HloInstructionProto instr;
3251 *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
3252 return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
3253 });
3254 }
3255
3256 XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
3257 absl::Span<const int64_t> window_dimensions,
3258 absl::Span<const int64_t> window_strides,
3259 Padding padding, XlaOp source,
3260 XlaOp init_value,
3261 const XlaComputation& scatter) {
3262 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3263 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3264
3265 std::vector<std::pair<int64_t, int64_t>> padding_values =
3266 MakePadding(operand_shape->dimensions(), window_dimensions,
3267 window_strides, padding);
3268
3269 TF_ASSIGN_OR_RETURN(auto window,
3270 ShapeInference::InferWindowFromDimensions(
3271 window_dimensions, window_strides, padding_values,
3272 /*lhs_dilation=*/{},
3273 /*rhs_dilation=*/{}));
3274 PaddingType padding_type = PADDING_INVALID;
3275 for (int64_t i = 0; i < operand_shape->rank(); ++i) {
3276 if (operand_shape->is_dynamic_dimension(i) &&
3277 !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
3278 padding == Padding::kSame) {
3279         // SAME padding can create dynamic padding sizes. The padding sizes
3280         // need to be rewritten by the dynamic padder using HloInstructions.
3281         // We create a CustomCall to handle this.
3282 padding_type = PADDING_SAME;
3283 }
3284 }
3285 if (padding_type == PADDING_SAME) {
3286 TF_ASSIGN_OR_RETURN(
3287 HloInstructionProto instr,
3288 SelectAndScatterInternal(operand, select, window_dimensions,
3289 window_strides, padding_values, source,
3290 init_value, scatter));
3291 instr.set_custom_call_target("DynamicSelectAndScatterSamePadding");
3292 return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
3293 {operand, source, init_value});
3294 }
3295 return SelectAndScatterWithGeneralPadding(
3296 operand, select, window_dimensions, window_strides, padding_values,
3297 source, init_value, scatter);
3298 });
3299 }
3300
3301 StatusOr<HloInstructionProto> XlaBuilder::SelectAndScatterInternal(
3302 XlaOp operand, const XlaComputation& select,
3303 absl::Span<const int64_t> window_dimensions,
3304 absl::Span<const int64_t> window_strides,
3305 absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
3306 XlaOp init_value, const XlaComputation& scatter) {
3307 HloInstructionProto instr;
3308
3309 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3310 TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
3311 TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
3312 TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
3313 select.GetProgramShape());
3314 TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
3315 scatter.GetProgramShape());
3316 TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
3317 ShapeInference::InferWindowFromDimensions(
3318 window_dimensions, window_strides, padding,
3319 /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
3320 TF_ASSIGN_OR_RETURN(Shape shape,
3321 ShapeInference::InferSelectAndScatterShape(
3322 *operand_shape, select_shape, instr.window(),
3323 *source_shape, *init_shape, scatter_shape));
3324 *instr.mutable_shape() = shape.ToProto();
3325
3326 AddCalledComputation(select, &instr);
3327 AddCalledComputation(scatter, &instr);
3328 return instr;
3329 }
3330
3331 XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
3332 XlaOp operand, const XlaComputation& select,
3333 absl::Span<const int64_t> window_dimensions,
3334 absl::Span<const int64_t> window_strides,
3335 absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
3336 XlaOp init_value, const XlaComputation& scatter) {
3337 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3338 TF_ASSIGN_OR_RETURN(HloInstructionProto instr,
3339 SelectAndScatterInternal(
3340 operand, select, window_dimensions, window_strides,
3341 padding, source, init_value, scatter));
3342
3343 return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
3344 {operand, source, init_value});
3345 });
3346 }
3347
3348 XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
3349 const int mantissa_bits) {
3350 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3351 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3352 TF_ASSIGN_OR_RETURN(Shape shape,
3353 ShapeInference::InferReducePrecisionShape(
3354 *operand_shape, exponent_bits, mantissa_bits));
3355 return ReducePrecisionInternal(shape, operand, exponent_bits,
3356 mantissa_bits);
3357 });
3358 }
3359
3360 StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
3361 XlaOp operand,
3362 const int exponent_bits,
3363 const int mantissa_bits) {
3364 HloInstructionProto instr;
3365 *instr.mutable_shape() = shape.ToProto();
3366 instr.set_exponent_bits(exponent_bits);
3367 instr.set_mantissa_bits(mantissa_bits);
3368 return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
3369 {operand});
3370 }
3371
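// Token-less Send wrapper: generates an AfterAll token internally and then
// emits Send/SendDone via SendWithToken.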
3372 void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
3373 ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3374 // Send HLO takes two operands: a data operand and a token. Generate the
3375 // token to pass into the send.
3376 // TODO(b/80000000): Remove this when clients have been updated to handle
3377 // tokens.
3378 HloInstructionProto token_instr;
3379 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
3380 TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
3381 HloOpcode::kAfterAll, {}));
3382
3383 return SendWithToken(operand, token, handle);
3384 });
3385 }
3386
3387 XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
3388 const ChannelHandle& handle) {
3389 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3390 if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
3391 return InvalidArgument("Send must use a device-to-device channel");
3392 }
3393
3394 // Send instruction produces a tuple of {aliased operand, U32 context,
3395 // token}.
3396 HloInstructionProto send_instr;
3397 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
3398 *send_instr.mutable_shape() =
3399 ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
3400 ShapeUtil::MakeTokenShape()})
3401 .ToProto();
3402 send_instr.set_channel_id(handle.handle());
3403 TF_ASSIGN_OR_RETURN(XlaOp send,
3404 AddInstruction(std::move(send_instr), HloOpcode::kSend,
3405 {operand, token}));
3406
3407 HloInstructionProto send_done_instr;
3408 *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
3409 send_done_instr.set_channel_id(handle.handle());
3410 return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
3411 {send});
3412 });
3413 }
3414
3415 XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
3416 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3417 // Recv HLO takes a single token operand. Generate the token to pass into
3418 // the Recv and RecvDone instructions.
3419 // TODO(b/80000000): Remove this when clients have been updated to handle
3420 // tokens.
3421 HloInstructionProto token_instr;
3422 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
3423 TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
3424 HloOpcode::kAfterAll, {}));
3425
3426 XlaOp recv = RecvWithToken(token, shape, handle);
3427
3428     // The RecvDone instruction produces a tuple of the data and a token.
3429     // Return the XLA op containing the data.
3430 // TODO(b/80000000): Remove this when clients have been updated to handle
3431 // tokens.
3432 HloInstructionProto recv_data;
3433 *recv_data.mutable_shape() = shape.ToProto();
3434 recv_data.set_tuple_index(0);
3435 return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
3436 {recv});
3437 });
3438 }
3439
3440 XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
3441 const ChannelHandle& handle) {
3442 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3443 if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
3444 return InvalidArgument("Recv must use a device-to-device channel");
3445 }
3446
3447 // Recv instruction produces a tuple of {receive buffer, U32 context,
3448 // token}.
3449 HloInstructionProto recv_instr;
3450 *recv_instr.mutable_shape() =
3451 ShapeUtil::MakeTupleShape(
3452 {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
3453 .ToProto();
3454 recv_instr.set_channel_id(handle.handle());
3455 TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
3456 HloOpcode::kRecv, {token}));
3457
3458 HloInstructionProto recv_done_instr;
3459 *recv_done_instr.mutable_shape() =
3460 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
3461 .ToProto();
3462 recv_done_instr.set_channel_id(handle.handle());
3463 return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
3464 {recv});
3465 });
3466 }
3467
3468 XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
3469 const Shape& shape_with_layout,
3470 const ChannelHandle& handle) {
3471 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3472 if (!LayoutUtil::HasLayout(shape_with_layout)) {
3473 return InvalidArgument("Shape passed to SendToHost must have a layout");
3474 }
3475 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3476 if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
3477 return InvalidArgument(
3478 "SendToHost shape %s must be compatible with operand shape %s",
3479 ShapeUtil::HumanStringWithLayout(shape_with_layout),
3480 ShapeUtil::HumanStringWithLayout(*operand_shape));
3481 }
3482 // TODO(b/111544877): Support tuple shapes.
3483 if (!operand_shape->IsArray()) {
3484 return InvalidArgument("SendToHost only supports array shapes, shape: %s",
3485 ShapeUtil::HumanString(*operand_shape));
3486 }
3487
3488 if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
3489 return InvalidArgument("SendToHost must use a device-to-host channel");
3490 }
3491
3492 // Send instruction produces a tuple of {aliased operand, U32 context,
3493 // token}.
3494 HloInstructionProto send_instr;
3495 *send_instr.mutable_shape() =
3496 ShapeUtil::MakeTupleShape({shape_with_layout,
3497 ShapeUtil::MakeShape(U32, {}),
3498 ShapeUtil::MakeTokenShape()})
3499 .ToProto();
3500 send_instr.set_channel_id(handle.handle());
3501 send_instr.set_is_host_transfer(true);
3502 TF_ASSIGN_OR_RETURN(XlaOp send,
3503 AddInstruction(std::move(send_instr), HloOpcode::kSend,
3504 {operand, token}));
3505
3506 HloInstructionProto send_done_instr;
3507 *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
3508 send_done_instr.set_channel_id(handle.handle());
3509 send_done_instr.set_is_host_transfer(true);
3510 TF_ASSIGN_OR_RETURN(XlaOp send_done,
3511 AddInstruction(std::move(send_done_instr),
3512 HloOpcode::kSendDone, {send}));
3513 return send_done;
3514 });
3515 }
3516
3517 XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
3518 const ChannelHandle& handle) {
3519 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3520 if (!LayoutUtil::HasLayout(shape)) {
3521 return InvalidArgument("Shape passed to RecvFromHost must have a layout");
3522 }
3523
3524 // TODO(b/111544877): Support tuple shapes.
3525 if (!shape.IsArray()) {
3526 return InvalidArgument(
3527 "RecvFromHost only supports array shapes, shape: %s",
3528 ShapeUtil::HumanString(shape));
3529 }
3530
3531 if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
3532 return InvalidArgument("RecvFromHost must use a host-to-device channel");
3533 }
3534
3535 // Recv instruction produces a tuple of {receive buffer, U32 context,
3536 // token}.
3537 HloInstructionProto recv_instr;
3538 *recv_instr.mutable_shape() =
3539 ShapeUtil::MakeTupleShape(
3540 {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
3541 .ToProto();
3542 recv_instr.set_channel_id(handle.handle());
3543 recv_instr.set_is_host_transfer(true);
3544 TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
3545 HloOpcode::kRecv, {token}));
3546
3547 HloInstructionProto recv_done_instr;
3548 *recv_done_instr.mutable_shape() =
3549 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
3550 .ToProto();
3551 recv_done_instr.set_channel_id(handle.handle());
3552 recv_done_instr.set_is_host_transfer(true);
3553 return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
3554 {recv});
3555 });
3556 }
3557
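// Returns the size of the given dimension as a 32-bit scalar op. For a static
// dimension this folds to a constant; for a dynamic dimension it emits a
// kGetDimensionSize instruction.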
3558 XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) {
3559 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3560 HloInstructionProto instr;
3561 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3562 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
3563 *operand_shape, dimension));
3564 // Calling GetDimensionSize on a static dimension returns a constant
3565 // instruction.
3566 if (!operand_shape->is_dynamic_dimension(dimension)) {
3567 return ConstantR0<int32_t>(this, operand_shape->dimensions(dimension));
3568 }
3569 *instr.mutable_shape() = shape.ToProto();
3570 instr.add_dimensions(dimension);
3571 return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
3572 {operand});
3573 });
3574 }
3575
3576 XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) {
3577 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3578 HloInstructionProto instr;
3579 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3580
3581 Shape shape = *operand_shape;
3582 shape.set_dynamic_dimension(dimension, false);
3583 // Setting an op's dynamic dimension to its static size removes the dynamic
3584 // dimension.
3585 XlaOp static_size =
3586 ConstantR0<int32_t>(this, operand_shape->dimensions(dimension));
3587 return SetDimensionSizeInternal(shape, operand, static_size, dimension);
3588 });
3589 }
3590
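// Marks `dimension` of `operand` as dynamic with runtime size `val` by
// emitting a kSetDimensionSize instruction (see SetDimensionSizeInternal for
// the constant-size fast path).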
3591 XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val,
3592 int64_t dimension) {
3593 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3594 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3595 TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));
3596
3597 TF_ASSIGN_OR_RETURN(Shape shape,
3598 ShapeInference::InferSetDimensionSizeShape(
3599 *operand_shape, *val_shape, dimension));
3600 return SetDimensionSizeInternal(shape, operand, val, dimension);
3601 });
3602 }
3603
3604 StatusOr<XlaOp> XlaBuilder::SetDimensionSizeInternal(const Shape& shape,
3605 XlaOp operand, XlaOp val,
3606 int64_t dimension) {
3607 TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
3608 LookUpInstruction(val));
3609 if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
3610 HloOpcode::kConstant &&
3611 shape.is_dynamic_dimension(dimension)) {
3612 TF_ASSIGN_OR_RETURN(auto constant_size,
3613 Literal::CreateFromProto(val_proto->literal(), true));
3614 if (constant_size.Get<int32_t>({}) == shape.dimensions(dimension)) {
3615 return operand;
3616 }
3617 }
3618
3619 HloInstructionProto instr;
3620 *instr.mutable_shape() = shape.ToProto();
3621 instr.add_dimensions(dimension);
3622 return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
3623 {operand, val});
3624 }
3625
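// Returns true if the value of `operand` does not (transitively) depend on any
// parameter; computed by walking the instruction graph with IsConstantVisitor.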
3626 StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
3627 TF_RETURN_IF_ERROR(first_error_);
3628
3629 // Verify that the handle is valid.
3630 TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
3631
3632 bool is_constant = true;
3633 absl::flat_hash_set<int64_t> visited;
3634 IsConstantVisitor(operand.handle(), /*depth=*/0, &visited, &is_constant);
3635 return is_constant;
3636 }
3637
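// Extracts the transitive closure of instructions feeding `root_op` into a
// standalone XlaComputation, replacing GetDimensionSize and bound-setting
// instructions with constants and looking through GetTupleElement(Tuple(...), i).
// Fails if the requested value depends on a parameter.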
3638 StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
3639 XlaOp root_op, bool dynamic_dimension_is_minus_one) {
3640 TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
3641 if (!is_constant) {
3642 auto op_status = LookUpInstruction(root_op);
3643 std::string op_string =
3644 op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
3645 return InvalidArgument(
3646 "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
3647 " op requested for constant subgraph: %s\n\n"
3648 "This is an internal error that typically happens when the XLA user "
3649 "(e.g. TensorFlow) is attempting to determine a value that must be a "
3650 "compile-time constant (e.g. an array dimension) but it is not capable "
3651 "of being evaluated at XLA compile time.\n\n"
3652 "Please file a usability bug with the framework being used (e.g. "
3653 "TensorFlow).",
3654 op_string);
3655 }
3656
3657 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
3658 LookUpInstruction(root_op));
3659 if (VLOG_IS_ON(4)) {
3660 VLOG(4) << "Build constant subgraph for:\n" << OpToString(root_op);
3661 }
3662
3663 HloComputationProto entry;
3664 SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
3665 GetNextId());
3666 ProgramShapeProto* program_shape = entry.mutable_program_shape();
3667 *program_shape->mutable_result() = root->shape();
3668
3669 // We use std::set to keep the instruction ids in ascending order (which is
3670 // also a valid dependency order). The related ops will be added to the
3671 // subgraph in the same order.
3672 std::set<int64_t> related_ops;
3673 absl::flat_hash_map<int64_t, int64_t> substitutions;
3674 absl::flat_hash_set<int64_t> related_calls; // Related computations.
3675 std::queue<int64_t> worklist;
3676 worklist.push(root->id());
3677 related_ops.insert(root->id());
3678
3679 while (!worklist.empty()) {
3680 int64_t handle = worklist.front();
3681 worklist.pop();
3682 TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
3683 LookUpInstructionByHandle(handle));
3684
3685 auto default_behavior = [&related_ops, &worklist, &related_calls,
3686 instr_proto]() {
3687 for (int64_t id : instr_proto->operand_ids()) {
3688 if (related_ops.insert(id).second) {
3689 worklist.push(id);
3690 }
3691 }
3692 for (int64_t called_id : instr_proto->called_computation_ids()) {
3693 related_calls.insert(called_id);
3694 }
3695 };
3696
3697 if (instr_proto->opcode() ==
3698 HloOpcodeString(HloOpcode::kGetDimensionSize) ||
3699 InstrIsSetBound(instr_proto)) {
3700 int32_t constant_value = -1;
3701 HloInstructionProto const_instr;
3702
3703 if (instr_proto->opcode() ==
3704 HloOpcodeString(HloOpcode::kGetDimensionSize)) {
3705           // At this point, BuildConstantSubGraph should never encounter a
3706           // GetDimensionSize with a dynamic dimension; the IsConstant check
3707           // would have failed at the beginning of this function.
3708 //
3709 // Replace GetDimensionSize with a Constant representing the static
3710 // bound of the shape.
3711 int64_t dimension = instr_proto->dimensions(0);
3712 int64_t operand_handle = instr_proto->operand_ids(0);
3713 TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
3714 LookUpInstructionByHandle(operand_handle));
3715
3716 if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
3717 dynamic_dimension_is_minus_one)) {
3718 constant_value = static_cast<int32_t>(
3719 operand_proto->shape().dimensions(dimension));
3720 }
3721 Literal literal = LiteralUtil::CreateR0(constant_value);
3722 *const_instr.mutable_literal() = literal.ToProto();
3723 *const_instr.mutable_shape() = literal.shape().ToProto();
3724 } else {
3725 if (instr_proto->literal().shape().element_type() == TUPLE) {
3726 *const_instr.mutable_literal() =
3727 // First literal of SetBound contains bounds, second literal
3728 // contains dynamism indicators.
3729 instr_proto->literal().tuple_literals(0);
3730 } else {
3731 *const_instr.mutable_literal() = instr_proto->literal();
3732 }
3733
3734 *const_instr.mutable_shape() = instr_proto->shape();
3735 }
3736 *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
3737 const_instr.set_id(handle);
3738 *const_instr.mutable_name() =
3739 GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
3740 *entry.add_instructions() =
3741 const_instr; // Add to the result constant graph.
3742
3743 } else if (instr_proto->opcode() ==
3744 HloOpcodeString(HloOpcode::kGetTupleElement)) {
3745 // Look through GTE(Tuple(..), i).
3746 TF_ASSIGN_OR_RETURN(
3747 const HloInstructionProto* maybe_tuple_instr,
3748 LookUpInstructionByHandle(instr_proto->operand_ids(0)));
3749
3750 if (maybe_tuple_instr->opcode() == HloOpcodeString(HloOpcode::kTuple)) {
3751 int64_t id = maybe_tuple_instr->operand_ids(instr_proto->tuple_index());
3752 // Enqueue any dependencies of `id`.
3753 if (related_ops.insert(id).second) {
3754 worklist.push(id);
3755 }
3756 substitutions[handle] = id;
3757
3758 } else {
3759 default_behavior();
3760 }
3761
3762 } else {
3763 default_behavior();
3764 }
3765 }
3766
3767 // Resolve any substitutions for the root id.
3768 int64_t root_id = root->id();
3769 auto it = substitutions.find(root_id);
3770 while (it != substitutions.end()) {
3771 root_id = it->second;
3772 it = substitutions.find(root_id);
3773 }
3774 entry.set_root_id(root_id);
3775
3776 // Add related ops to the computation.
3777 for (int64_t id : related_ops) {
3778 if (substitutions.find(id) != substitutions.end()) {
3779 // Skip adding this instruction; we will replace references to it with the
3780 // substitution instruction's id.
3781 continue;
3782 }
3783 TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
3784 LookUpInstructionByHandle(id));
3785
3786 if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize) ||
3787 InstrIsSetBound(instr_src)) {
3788 continue;
3789 }
3790 HloInstructionProto* instr = entry.add_instructions();
3791 *instr = *instr_src;
3792 // Replace operands in case we have substitutions mapped.
3793 instr->clear_operand_ids();
3794 for (int64_t operand_id : instr_src->operand_ids()) {
3795 auto it = substitutions.find(operand_id);
3796 while (it != substitutions.end()) {
3797 operand_id = it->second;
3798 it = substitutions.find(operand_id);
3799 }
3800 instr->add_operand_ids(operand_id);
3801 }
3802 // Ensures that the instruction names are unique within the graph.
3803 const std::string& new_name =
3804 StrCat(instr->name(), ".", entry.id(), ".", instr->id());
3805 instr->set_name(new_name);
3806 }
3807
3808 XlaComputation computation(entry.id());
3809 HloModuleProto* module = computation.mutable_proto();
3810 module->set_name(entry.name());
3811 module->set_id(entry.id());
3812 module->set_entry_computation_name(entry.name());
3813 module->set_entry_computation_id(entry.id());
3814 *module->mutable_host_program_shape() = *program_shape;
3815 for (auto& e : embedded_) {
3816 if (related_calls.find(e.second.id()) != related_calls.end()) {
3817 *module->add_computations() = e.second;
3818 }
3819 }
3820 *module->add_computations() = std::move(entry);
3821 if (VLOG_IS_ON(4)) {
3822 VLOG(4) << "Constant computation:\n" << module->DebugString();
3823 }
3824 return std::move(computation);
3825 }
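// A minimal usage sketch for BuildConstantSubGraph (illustrative only, not
// part of this file): it assumes the usual client API declared in
// xla_builder.h and elides error handling.
//
//   XlaBuilder b("compute_constant_example");
//   XlaOp x = ConstantLiteral(&b, LiteralUtil::CreateR0<int32_t>(2));
//   XlaOp y = ConstantLiteral(&b, LiteralUtil::CreateR0<int32_t>(3));
//   XlaOp sum = Add(x, y);
//   // Succeeds because `sum` transitively depends only on constants; if it
//   // depended on a Parameter, this would return the InvalidArgument above.
//   StatusOr<XlaComputation> constant_graph =
//       b.BuildConstantSubGraph(sum, /*dynamic_dimension_is_minus_one=*/false);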
3826
CreateSubBuilder(const std::string & computation_name)3827 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
3828 const std::string& computation_name) {
3829 auto sub_builder = std::make_unique<XlaBuilder>(computation_name);
3830 sub_builder->parent_builder_ = this;
3831 sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
3832 return sub_builder;
3833 }
3834
3835 /* static */ ConvolutionDimensionNumbers
CreateDefaultConvDimensionNumbers(int num_spatial_dims)3836 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
3837 ConvolutionDimensionNumbers dimension_numbers;
3838 dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
3839 dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
3840 dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
3841 dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
3842 dimension_numbers.set_kernel_output_feature_dimension(
3843 kConvKernelOutputDimension);
3844 dimension_numbers.set_kernel_input_feature_dimension(
3845 kConvKernelInputDimension);
3846 for (int i = 0; i < num_spatial_dims; ++i) {
3847 dimension_numbers.add_input_spatial_dimensions(i + 2);
3848 dimension_numbers.add_kernel_spatial_dimensions(i + 2);
3849 dimension_numbers.add_output_spatial_dimensions(i + 2);
3850 }
3851 return dimension_numbers;
3852 }
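// Sketch of what the defaults above produce for two spatial dimensions
// (illustrative only; it assumes the batch/feature dimension constants are 0
// and 1, i.e. NCHW activations and OIHW kernels, with spatial dimensions at
// indices 2 and 3):
//
//   ConvolutionDimensionNumbers dnums =
//       XlaBuilder::CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/2);
//   // dnums.input_spatial_dimensions()  -> {2, 3}
//   // dnums.kernel_spatial_dimensions() -> {2, 3}
//   // dnums.output_spatial_dimensions() -> {2, 3}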
3853
Validate(const ConvolutionDimensionNumbers & dnum)3854 /* static */ Status XlaBuilder::Validate(
3855 const ConvolutionDimensionNumbers& dnum) {
3856 if (dnum.input_spatial_dimensions_size() < 2) {
3857 return FailedPrecondition("input spatial dimension < 2: %d",
3858 dnum.input_spatial_dimensions_size());
3859 }
3860 if (dnum.kernel_spatial_dimensions_size() < 2) {
3861 return FailedPrecondition("kernel spatial dimension < 2: %d",
3862 dnum.kernel_spatial_dimensions_size());
3863 }
3864 if (dnum.output_spatial_dimensions_size() < 2) {
3865 return FailedPrecondition("output spatial dimension < 2: %d",
3866 dnum.output_spatial_dimensions_size());
3867 }
3868
3869 if (std::set<int64_t>(
3870 {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
3871 dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
3872 .size() != 4) {
3873 return FailedPrecondition(
3874 "dimension numbers for the input are not unique: (%d, %d, %d, "
3875 "%d)",
3876 dnum.input_batch_dimension(), dnum.input_feature_dimension(),
3877 dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
3878 }
3879 if (std::set<int64_t>({dnum.kernel_output_feature_dimension(),
3880 dnum.kernel_input_feature_dimension(),
3881 dnum.kernel_spatial_dimensions(0),
3882 dnum.kernel_spatial_dimensions(1)})
3883 .size() != 4) {
3884 return FailedPrecondition(
3885 "dimension numbers for the weight are not unique: (%d, %d, %d, "
3886 "%d)",
3887 dnum.kernel_output_feature_dimension(),
3888 dnum.kernel_input_feature_dimension(),
3889 dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
3890 }
3891 if (std::set<int64_t>({dnum.output_batch_dimension(),
3892 dnum.output_feature_dimension(),
3893 dnum.output_spatial_dimensions(0),
3894 dnum.output_spatial_dimensions(1)})
3895 .size() != 4) {
3896 return FailedPrecondition(
3897 "dimension numbers for the output are not unique: (%d, %d, %d, "
3898 "%d)",
3899 dnum.output_batch_dimension(), dnum.output_feature_dimension(),
3900 dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
3901 }
3902 return OkStatus();
3903 }
3904
AddInstruction(HloInstructionProto && instr,HloOpcode opcode,absl::Span<const XlaOp> operands)3905 StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
3906 HloOpcode opcode,
3907 absl::Span<const XlaOp> operands) {
3908 TF_RETURN_IF_ERROR(first_error_);
3909
3910 const int64_t handle = GetNextId();
3911 instr.set_id(handle);
3912 instr.set_opcode(HloOpcodeString(opcode));
3913 if (instr.name().empty()) {
3914 instr.set_name(instr.opcode());
3915 }
3916 for (const auto& operand : operands) {
3917 if (operand.builder_ == nullptr) {
3918 return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
3919 }
3920 if (operand.builder_ != this) {
3921 return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
3922 operand.builder_->name(), this->name());
3923 }
3924 instr.add_operand_ids(operand.handle());
3925 }
3926
3927 if (one_shot_metadata_.has_value()) {
3928 *instr.mutable_metadata() = one_shot_metadata_.value();
3929 one_shot_metadata_.reset();
3930 } else {
3931 *instr.mutable_metadata() = metadata_;
3932 }
3933 if (sharding_) {
3934 *instr.mutable_sharding() = *sharding_;
3935 }
3936 *instr.mutable_frontend_attributes() = frontend_attributes_;
3937
3938 handle_to_index_[handle] = instructions_.size();
3939 instructions_.push_back(std::move(instr));
3940 instruction_shapes_.push_back(
3941 std::make_unique<Shape>(instructions_.back().shape()));
3942
3943 XlaOp op(handle, this);
3944 return op;
3945 }
3946
AddOpWithShape(HloOpcode opcode,const Shape & shape,absl::Span<const XlaOp> operands)3947 StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
3948 absl::Span<const XlaOp> operands) {
3949 HloInstructionProto instr;
3950 *instr.mutable_shape() = shape.ToProto();
3951 return AddInstruction(std::move(instr), opcode, operands);
3952 }
3953
AddCalledComputation(const XlaComputation & computation,HloInstructionProto * instr)3954 void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
3955 HloInstructionProto* instr) {
3956 absl::flat_hash_map<int64_t, int64_t> remapped_ids;
3957 std::vector<HloComputationProto> imported_computations;
3958 imported_computations.reserve(computation.proto().computations_size());
3959 // First, import the computations, remapping their IDs and capturing the
3960 // old->new mappings in remapped_ids.
3961 for (const HloComputationProto& e : computation.proto().computations()) {
3962 HloComputationProto new_computation(e);
3963 int64_t computation_id = GetNextId();
3964 remapped_ids[new_computation.id()] = computation_id;
3965 SetProtoIdAndName(&new_computation,
3966 GetBaseName(new_computation.name(), kNameSeparator),
3967 kNameSeparator, computation_id);
3968 for (auto& instruction : *new_computation.mutable_instructions()) {
3969 int64_t instruction_id = GetNextId();
3970 remapped_ids[instruction.id()] = instruction_id;
3971 SetProtoIdAndName(&instruction,
3972 GetBaseName(instruction.name(), kNameSeparator),
3973 kNameSeparator, instruction_id);
3974 }
3975 new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));
3976
3977 imported_computations.push_back(std::move(new_computation));
3978 }
3979 // Once we have imported all the computations and captured all the ID
3980 // mappings, we go back and fix up the IDs in the imported computations.
3981 instr->add_called_computation_ids(
3982 remapped_ids.at(computation.proto().entry_computation_id()));
3983 for (auto& imported_computation : imported_computations) {
3984 for (auto& instruction : *imported_computation.mutable_instructions()) {
3985 for (auto& operand_id : *instruction.mutable_operand_ids()) {
3986 operand_id = remapped_ids.at(operand_id);
3987 }
3988 for (auto& control_predecessor_id :
3989 *instruction.mutable_control_predecessor_ids()) {
3990 control_predecessor_id = remapped_ids.at(control_predecessor_id);
3991 }
3992 for (auto& called_computation_id :
3993 *instruction.mutable_called_computation_ids()) {
3994 called_computation_id = remapped_ids.at(called_computation_id);
3995 }
3996 }
3997
3998 int64_t computation_id = imported_computation.id();
3999 for (int64_t i = 0; i < imported_computation.instructions_size(); ++i) {
4000 ImportedInstruction imported_instruction;
4001 imported_instruction.computation_id = computation_id;
4002 imported_instruction.instruction_index = i;
4003 handle_to_imported_index_.insert(
4004 {imported_computation.instructions(i).id(), imported_instruction});
4005 }
4006 embedded_.insert({computation_id, std::move(imported_computation)});
4007 }
4008 }
4009
LookUpInstruction(const XlaOp op) const4010 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
4011 const XlaOp op) const {
4012 TF_RETURN_IF_ERROR(first_error_);
4013 return LookUpInstructionInternal<const HloInstructionProto*>(op);
4014 }
4015
LookUpInstructionByHandle(int64_t handle) const4016 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
4017 int64_t handle) const {
4018 return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
4019 }
4020
LookUpMutableInstruction(const XlaOp op)4021 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
4022 const XlaOp op) {
4023 TF_RETURN_IF_ERROR(first_error_);
4024 return LookUpInstructionInternal<HloInstructionProto*>(op);
4025 }
4026
LookUpMutableInstructionByHandle(int64_t handle)4027 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
4028 int64_t handle) {
4029 return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
4030 }
4031
4032 // Enqueues a "retrieve parameter value" instruction for a parameter that was
4033 // passed to the computation.
Parameter(XlaBuilder * builder,int64_t parameter_number,const Shape & shape,const std::string & name)4034 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
4035 const Shape& shape, const std::string& name) {
4036 std::vector<bool> empty_bools;
4037 return Parameter(builder, parameter_number, shape, name, empty_bools);
4038 }
4039
Parameter(XlaBuilder * builder,int64_t parameter_number,const Shape & shape,const std::string & name,const std::vector<bool> & replicated_at_leaf_buffers)4040 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
4041 const Shape& shape, const std::string& name,
4042 const std::vector<bool>& replicated_at_leaf_buffers) {
4043 return builder->Parameter(parameter_number, shape, name,
4044 replicated_at_leaf_buffers);
4045 }
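// Minimal sketch (for illustration; not part of this translation unit) of how
// these free functions are typically used to declare parameters and build a
// computation:
//
//   XlaBuilder b("add_two_scalars");
//   Shape scalar = ShapeUtil::MakeShape(F32, {});
//   XlaOp x = Parameter(&b, /*parameter_number=*/0, scalar, "x");
//   XlaOp y = Parameter(&b, /*parameter_number=*/1, scalar, "y");
//   Add(x, y);  // The last enqueued op becomes the root.
//   StatusOr<XlaComputation> computation = b.Build();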
4046
4047 // Enqueues a constant with the value of the given literal onto the
4048 // computation.
ConstantLiteral(XlaBuilder * builder,const LiteralSlice & literal)4049 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
4050 return builder->ConstantLiteral(literal);
4051 }
4052
Broadcast(const XlaOp operand,absl::Span<const int64_t> broadcast_sizes)4053 XlaOp Broadcast(const XlaOp operand,
4054 absl::Span<const int64_t> broadcast_sizes) {
4055 return operand.builder()->Broadcast(operand, broadcast_sizes);
4056 }
4057
BroadcastInDim(const XlaOp operand,const absl::Span<const int64_t> out_dim_size,const absl::Span<const int64_t> broadcast_dimensions)4058 XlaOp BroadcastInDim(const XlaOp operand,
4059 const absl::Span<const int64_t> out_dim_size,
4060 const absl::Span<const int64_t> broadcast_dimensions) {
4061 return operand.builder()->BroadcastInDim(operand, out_dim_size,
4062 broadcast_dimensions);
4063 }
4064
Copy(const XlaOp operand)4065 XlaOp Copy(const XlaOp operand) {
4066 return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
4067 }
4068
Pad(const XlaOp operand,const XlaOp padding_value,const PaddingConfig & padding_config)4069 XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
4070 const PaddingConfig& padding_config) {
4071 return operand.builder()->Pad(operand, padding_value, padding_config);
4072 }
4073
PadInDim(XlaOp operand,XlaOp padding_value,int64_t dimno,int64_t pad_lo,int64_t pad_hi)4074 XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
4075 int64_t pad_lo, int64_t pad_hi) {
4076 return operand.builder()->PadInDim(operand, padding_value, dimno, pad_lo,
4077 pad_hi);
4078 }
4079
Reshape(const XlaOp operand,absl::Span<const int64_t> dimensions,absl::Span<const int64_t> new_sizes)4080 XlaOp Reshape(const XlaOp operand, absl::Span<const int64_t> dimensions,
4081 absl::Span<const int64_t> new_sizes) {
4082 return operand.builder()->Reshape(operand, dimensions, new_sizes);
4083 }
4084
Reshape(const XlaOp operand,absl::Span<const int64_t> new_sizes)4085 XlaOp Reshape(const XlaOp operand, absl::Span<const int64_t> new_sizes) {
4086 return operand.builder()->Reshape(operand, new_sizes);
4087 }
4088
Reshape(const Shape & shape,XlaOp operand)4089 XlaOp Reshape(const Shape& shape, XlaOp operand) {
4090 return operand.builder()->Reshape(shape, operand);
4091 }
4092
DynamicReshape(XlaOp operand,absl::Span<const XlaOp> dim_sizes,absl::Span<const int64_t> new_size_bounds,const std::vector<bool> & dims_are_dynamic)4093 XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
4094 absl::Span<const int64_t> new_size_bounds,
4095 const std::vector<bool>& dims_are_dynamic) {
4096 return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
4097 dims_are_dynamic);
4098 }
4099
ReshapeWithInferredDimension(XlaOp operand,absl::Span<const int64_t> new_sizes,int64_t inferred_dimension)4100 XlaOp ReshapeWithInferredDimension(XlaOp operand,
4101 absl::Span<const int64_t> new_sizes,
4102 int64_t inferred_dimension) {
4103 return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
4104 }
4105
Collapse(const XlaOp operand,absl::Span<const int64_t> dimensions)4106 XlaOp Collapse(const XlaOp operand, absl::Span<const int64_t> dimensions) {
4107 return operand.builder()->Collapse(operand, dimensions);
4108 }
4109
Slice(const XlaOp operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)4110 XlaOp Slice(const XlaOp operand, absl::Span<const int64_t> start_indices,
4111 absl::Span<const int64_t> limit_indices,
4112 absl::Span<const int64_t> strides) {
4113 return operand.builder()->Slice(operand, start_indices, limit_indices,
4114 strides);
4115 }
4116
SliceInDim(const XlaOp operand,int64_t start_index,int64_t limit_index,int64_t stride,int64_t dimno)4117 XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index,
4118 int64_t stride, int64_t dimno) {
4119 return operand.builder()->SliceInDim(operand, start_index, limit_index,
4120 stride, dimno);
4121 }
4122
DynamicSlice(const XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64_t> slice_sizes)4123 XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
4124 absl::Span<const int64_t> slice_sizes) {
4125 return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
4126 }
4127
DynamicUpdateSlice(const XlaOp operand,const XlaOp update,absl::Span<const XlaOp> start_indices)4128 XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
4129 absl::Span<const XlaOp> start_indices) {
4130 return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
4131 }
4132
ConcatInDim(XlaBuilder * builder,absl::Span<const XlaOp> operands,int64_t dimension)4133 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
4134 int64_t dimension) {
4135 return builder->ConcatInDim(operands, dimension);
4136 }
4137
Select(const XlaOp pred,const XlaOp on_true,const XlaOp on_false)4138 XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
4139 return pred.builder()->Select(pred, on_true, on_false);
4140 }
4141
Tuple(XlaBuilder * builder,absl::Span<const XlaOp> elements)4142 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
4143 return builder->Tuple(elements);
4144 }
4145
GetTupleElement(const XlaOp tuple_data,int64_t index)4146 XlaOp GetTupleElement(const XlaOp tuple_data, int64_t index) {
4147 return tuple_data.builder()->GetTupleElement(tuple_data, index);
4148 }
4149
Eq(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4150 XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
4151 absl::Span<const int64_t> broadcast_dimensions) {
4152 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
4153 }
4154
CompareTotalOrder(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions,ComparisonDirection comparison_direction)4155 static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
4156 absl::Span<const int64_t> broadcast_dimensions,
4157 ComparisonDirection comparison_direction) {
4158 auto b = lhs.builder();
4159 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
4160 TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
4161 auto operand_element_type = operand_shape.element_type();
4162 auto compare_type =
4163 primitive_util::IsFloatingPointType(operand_element_type)
4164 ? Comparison::Type::kFloatTotalOrder
4165 : Comparison::DefaultComparisonType(operand_element_type);
4166 return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
4167 compare_type);
4168 });
4169 }
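// Illustrative sketch of how the total-order variants differ from the default
// comparisons for floating-point inputs (this assumes the total order places a
// positive NaN above every non-NaN value):
//
//   XlaOp nan = ConstantLiteral(
//       &builder, LiteralUtil::CreateR0<float>(
//                     std::numeric_limits<float>::quiet_NaN()));
//   XlaOp one = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(1.0f));
//   XlaOp p = Gt(nan, one);            // false: IEEE semantics, NaN unordered
//   XlaOp q = GtTotalOrder(nan, one);  // true under the total order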
4170
EqTotalOrder(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4171 XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
4172 absl::Span<const int64_t> broadcast_dimensions) {
4173 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4174 ComparisonDirection::kEq);
4175 }
4176
Ne(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4177 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
4178 absl::Span<const int64_t> broadcast_dimensions) {
4179 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
4180 }
4181
NeTotalOrder(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4182 XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4183 absl::Span<const int64_t> broadcast_dimensions) {
4184 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4185 ComparisonDirection::kNe);
4186 }
4187
Ge(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4188 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
4189 absl::Span<const int64_t> broadcast_dimensions) {
4190 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
4191 }
4192
GeTotalOrder(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4193 XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4194 absl::Span<const int64_t> broadcast_dimensions) {
4195 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4196 ComparisonDirection::kGe);
4197 }
4198
Gt(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4199 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
4200 absl::Span<const int64_t> broadcast_dimensions) {
4201 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
4202 }
4203
GtTotalOrder(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4204 XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
4205 absl::Span<const int64_t> broadcast_dimensions) {
4206 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4207 ComparisonDirection::kGt);
4208 }
4209
Le(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4210 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
4211 absl::Span<const int64_t> broadcast_dimensions) {
4212 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
4213 }
4214
LeTotalOrder(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4215 XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4216 absl::Span<const int64_t> broadcast_dimensions) {
4217 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4218 ComparisonDirection::kLe);
4219 }
4220
Lt(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4221 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
4222 absl::Span<const int64_t> broadcast_dimensions) {
4223 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
4224 }
4225
LtTotalOrder(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4226 XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
4227 absl::Span<const int64_t> broadcast_dimensions) {
4228 return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4229 ComparisonDirection::kLt);
4230 }
4231
Compare(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions,ComparisonDirection direction)4232 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
4233 absl::Span<const int64_t> broadcast_dimensions,
4234 ComparisonDirection direction) {
4235 return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
4236 broadcast_dimensions, direction);
4237 }
4238
Compare(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions,ComparisonDirection direction,Comparison::Type compare_type)4239 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
4240 absl::Span<const int64_t> broadcast_dimensions,
4241 ComparisonDirection direction, Comparison::Type compare_type) {
4242 return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
4243 broadcast_dimensions, direction, compare_type);
4244 }
4245
Compare(const XlaOp lhs,const XlaOp rhs,ComparisonDirection direction)4246 XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
4247 return Compare(lhs, rhs, {}, direction);
4248 }
4249
Dot(const XlaOp lhs,const XlaOp rhs,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4250 XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
4251 const PrecisionConfig* precision_config,
4252 std::optional<PrimitiveType> preferred_element_type) {
4253 return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
4254 }
4255
DotGeneral(const XlaOp lhs,const XlaOp rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4256 XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
4257 const DotDimensionNumbers& dimension_numbers,
4258 const PrecisionConfig* precision_config,
4259 std::optional<PrimitiveType> preferred_element_type) {
4260 return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
4261 precision_config, preferred_element_type);
4262 }
4263
Conv(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,Padding padding,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4264 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
4265 absl::Span<const int64_t> window_strides, Padding padding,
4266 int64_t feature_group_count, int64_t batch_group_count,
4267 const PrecisionConfig* precision_config,
4268 std::optional<PrimitiveType> preferred_element_type) {
4269 return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
4270 feature_group_count, batch_group_count,
4271 precision_config, preferred_element_type);
4272 }
4273
ConvWithGeneralPadding(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4274 XlaOp ConvWithGeneralPadding(
4275 const XlaOp lhs, const XlaOp rhs, absl::Span<const int64_t> window_strides,
4276 absl::Span<const std::pair<int64_t, int64_t>> padding,
4277 int64_t feature_group_count, int64_t batch_group_count,
4278 const PrecisionConfig* precision_config,
4279 std::optional<PrimitiveType> preferred_element_type) {
4280 return lhs.builder()->ConvWithGeneralPadding(
4281 lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
4282 precision_config, preferred_element_type);
4283 }
4284
ConvWithGeneralDimensions(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,Padding padding,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4285 XlaOp ConvWithGeneralDimensions(
4286 const XlaOp lhs, const XlaOp rhs, absl::Span<const int64_t> window_strides,
4287 Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
4288 int64_t feature_group_count, int64_t batch_group_count,
4289 const PrecisionConfig* precision_config,
4290 std::optional<PrimitiveType> preferred_element_type) {
4291 return lhs.builder()->ConvWithGeneralDimensions(
4292 lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4293 batch_group_count, precision_config, preferred_element_type);
4294 }
4295
ConvGeneral(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4296 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
4297 absl::Span<const int64_t> window_strides,
4298 absl::Span<const std::pair<int64_t, int64_t>> padding,
4299 const ConvolutionDimensionNumbers& dimension_numbers,
4300 int64_t feature_group_count, int64_t batch_group_count,
4301 const PrecisionConfig* precision_config,
4302 std::optional<PrimitiveType> preferred_element_type) {
4303 return lhs.builder()->ConvGeneral(
4304 lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4305 batch_group_count, precision_config, preferred_element_type);
4306 }
4307
ConvGeneralDilated(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,absl::Span<const int64_t> lhs_dilation,absl::Span<const int64_t> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type,std::optional<std::vector<bool>> window_reversal)4308 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
4309 absl::Span<const int64_t> window_strides,
4310 absl::Span<const std::pair<int64_t, int64_t>> padding,
4311 absl::Span<const int64_t> lhs_dilation,
4312 absl::Span<const int64_t> rhs_dilation,
4313 const ConvolutionDimensionNumbers& dimension_numbers,
4314 int64_t feature_group_count, int64_t batch_group_count,
4315 const PrecisionConfig* precision_config,
4316 std::optional<PrimitiveType> preferred_element_type,
4317 std::optional<std::vector<bool>> window_reversal) {
4318 return lhs.builder()->ConvGeneralDilated(
4319 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
4320 dimension_numbers, feature_group_count, batch_group_count,
4321 precision_config, preferred_element_type, window_reversal);
4322 }
4323
DynamicConvInputGrad(XlaOp input_sizes,const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,absl::Span<const int64_t> lhs_dilation,absl::Span<const int64_t> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,PaddingType padding_type,std::optional<PrimitiveType> preferred_element_type)4324 XlaOp DynamicConvInputGrad(
4325 XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
4326 absl::Span<const int64_t> window_strides,
4327 absl::Span<const std::pair<int64_t, int64_t>> padding,
4328 absl::Span<const int64_t> lhs_dilation,
4329 absl::Span<const int64_t> rhs_dilation,
4330 const ConvolutionDimensionNumbers& dimension_numbers,
4331 int64_t feature_group_count, int64_t batch_group_count,
4332 const PrecisionConfig* precision_config, PaddingType padding_type,
4333 std::optional<PrimitiveType> preferred_element_type) {
4334 return lhs.builder()->DynamicConvInputGrad(
4335 input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
4336 rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
4337 precision_config, padding_type, preferred_element_type);
4338 }
4339
DynamicConvKernelGrad(XlaOp activations,XlaOp gradients,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,absl::Span<const int64_t> lhs_dilation,absl::Span<const int64_t> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,PaddingType padding_type,std::optional<PrimitiveType> preferred_element_type)4340 XlaOp DynamicConvKernelGrad(
4341 XlaOp activations, XlaOp gradients,
4342 absl::Span<const int64_t> window_strides,
4343 absl::Span<const std::pair<int64_t, int64_t>> padding,
4344 absl::Span<const int64_t> lhs_dilation,
4345 absl::Span<const int64_t> rhs_dilation,
4346 const ConvolutionDimensionNumbers& dimension_numbers,
4347 int64_t feature_group_count, int64_t batch_group_count,
4348 const PrecisionConfig* precision_config, PaddingType padding_type,
4349 std::optional<PrimitiveType> preferred_element_type) {
4350 return activations.builder()->DynamicConvKernelGrad(
4351 activations, gradients, window_strides, padding, lhs_dilation,
4352 rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
4353 precision_config, padding_type, preferred_element_type);
4354 }
4355
DynamicConvForward(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,absl::Span<const int64_t> lhs_dilation,absl::Span<const int64_t> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,PaddingType padding_type,std::optional<PrimitiveType> preferred_element_type)4356 XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
4357 absl::Span<const int64_t> window_strides,
4358 absl::Span<const std::pair<int64_t, int64_t>> padding,
4359 absl::Span<const int64_t> lhs_dilation,
4360 absl::Span<const int64_t> rhs_dilation,
4361 const ConvolutionDimensionNumbers& dimension_numbers,
4362 int64_t feature_group_count, int64_t batch_group_count,
4363 const PrecisionConfig* precision_config,
4364 PaddingType padding_type,
4365 std::optional<PrimitiveType> preferred_element_type) {
4366 return lhs.builder()->DynamicConvForward(
4367 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
4368 dimension_numbers, feature_group_count, batch_group_count,
4369 precision_config, padding_type, preferred_element_type);
4370 }
4371
Fft(const XlaOp operand,FftType fft_type,absl::Span<const int64_t> fft_length)4372 XlaOp Fft(const XlaOp operand, FftType fft_type,
4373 absl::Span<const int64_t> fft_length) {
4374 return operand.builder()->Fft(operand, fft_type, fft_length);
4375 }
4376
TriangularSolve(XlaOp a,XlaOp b,bool left_side,bool lower,bool unit_diagonal,TriangularSolveOptions::Transpose transpose_a)4377 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
4378 bool unit_diagonal,
4379 TriangularSolveOptions::Transpose transpose_a) {
4380 XlaBuilder* builder = a.builder();
4381 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
4382 TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
4383 TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
4384 TriangularSolveOptions options;
4385 options.set_left_side(left_side);
4386 options.set_lower(lower);
4387 options.set_unit_diagonal(unit_diagonal);
4388 options.set_transpose_a(transpose_a);
4389 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
4390 *a_shape, *b_shape, options));
4391 return builder->TriangularSolveInternal(shape, a, b, std::move(options));
4392 });
4393 }
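// Usage sketch (illustrative; assumes `a` is a lower-triangular f32[3,3]
// matrix and `b` an f32[3,1] right-hand side already enqueued on the same
// builder):
//
//   XlaOp x = TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);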
4394
Cholesky(XlaOp a,bool lower)4395 XlaOp Cholesky(XlaOp a, bool lower) {
4396 XlaBuilder* builder = a.builder();
4397 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
4398 TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
4399 TF_ASSIGN_OR_RETURN(Shape shape,
4400 ShapeInference::InferCholeskyShape(*a_shape));
4401 return builder->CholeskyInternal(shape, a, lower);
4402 });
4403 }
4404
Infeed(XlaBuilder * builder,const Shape & shape,const std::string & config)4405 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
4406 const std::string& config) {
4407 return builder->Infeed(shape, config);
4408 }
4409
Outfeed(const XlaOp operand,const Shape & shape_with_layout,const std::string & outfeed_config)4410 void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
4411 const std::string& outfeed_config) {
4412 return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
4413 }
4414
Call(XlaBuilder * builder,const XlaComputation & computation,absl::Span<const XlaOp> operands)4415 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
4416 absl::Span<const XlaOp> operands) {
4417 return builder->Call(computation, operands);
4418 }
4419
CustomCall(XlaBuilder * builder,const std::string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,const std::string & opaque,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> output_operand_aliasing,const Literal * literal,CustomCallSchedule schedule,CustomCallApiVersion api_version)4420 XlaOp CustomCall(
4421 XlaBuilder* builder, const std::string& call_target_name,
4422 absl::Span<const XlaOp> operands, const Shape& shape,
4423 const std::string& opaque, bool has_side_effect,
4424 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
4425 output_operand_aliasing,
4426 const Literal* literal, CustomCallSchedule schedule,
4427 CustomCallApiVersion api_version) {
4428 return builder->CustomCall(call_target_name, operands, shape, opaque,
4429 /*operand_shapes_with_layout=*/std::nullopt,
4430 has_side_effect, output_operand_aliasing, literal,
4431 /*window=*/std::nullopt, /*dnums=*/std::nullopt,
4432 schedule, api_version);
4433 }
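// Hypothetical usage sketch: the target name "my_custom_fn" is an assumption,
// a real target must be registered with the backend, and the trailing
// arguments are assumed to be defaulted by the declaration in xla_builder.h.
//
//   XlaOp arg = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "arg");
//   XlaOp out = CustomCall(&b, "my_custom_fn", /*operands=*/{arg},
//                          /*shape=*/ShapeUtil::MakeShape(F32, {2}));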
4434
CustomCallWithComputation(XlaBuilder * builder,const std::string & call_target_name,absl::Span<const XlaOp> operands,const XlaComputation & computation,const Shape & shape,const std::string & opaque,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> output_operand_aliasing,const Literal * literal,CustomCallSchedule schedule,CustomCallApiVersion api_version)4435 XlaOp CustomCallWithComputation(
4436 XlaBuilder* builder, const std::string& call_target_name,
4437 absl::Span<const XlaOp> operands, const XlaComputation& computation,
4438 const Shape& shape, const std::string& opaque, bool has_side_effect,
4439 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
4440 output_operand_aliasing,
4441 const Literal* literal, CustomCallSchedule schedule,
4442 CustomCallApiVersion api_version) {
4443 return builder->CustomCall(
4444 call_target_name, operands, computation, shape, opaque,
4445 /*operand_shapes_with_layout=*/std::nullopt, has_side_effect,
4446 output_operand_aliasing, literal, schedule, api_version);
4447 }
4448
CustomCallWithLayout(XlaBuilder * builder,const std::string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,absl::Span<const Shape> operand_shapes_with_layout,const std::string & opaque,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> output_operand_aliasing,const Literal * literal,CustomCallSchedule schedule,CustomCallApiVersion api_version)4449 XlaOp CustomCallWithLayout(
4450 XlaBuilder* builder, const std::string& call_target_name,
4451 absl::Span<const XlaOp> operands, const Shape& shape,
4452 absl::Span<const Shape> operand_shapes_with_layout,
4453 const std::string& opaque, bool has_side_effect,
4454 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
4455 output_operand_aliasing,
4456 const Literal* literal, CustomCallSchedule schedule,
4457 CustomCallApiVersion api_version) {
4458 return builder->CustomCall(
4459 call_target_name, operands, shape, opaque, operand_shapes_with_layout,
4460 has_side_effect, output_operand_aliasing, literal,
4461 /*window=*/std::nullopt, /*dnums=*/std::nullopt, schedule, api_version);
4462 }
4463
CustomCallWithConvDnums(XlaBuilder * builder,const std::string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,absl::Span<const Shape> operand_shapes_with_layout,const std::string & opaque,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> output_operand_aliasing,const Literal * literal,Window window,ConvolutionDimensionNumbers dnums,CustomCallSchedule schedule,CustomCallApiVersion api_version)4464 XlaOp CustomCallWithConvDnums(
4465 XlaBuilder* builder, const std::string& call_target_name,
4466 absl::Span<const XlaOp> operands, const Shape& shape,
4467 absl::Span<const Shape> operand_shapes_with_layout,
4468 const std::string& opaque, bool has_side_effect,
4469 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
4470 output_operand_aliasing,
4471 const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
4472 CustomCallSchedule schedule, CustomCallApiVersion api_version) {
4473 std::optional<absl::Span<const Shape>> maybe_operand_shapes;
4474 if (!operand_shapes_with_layout.empty()) {
4475 maybe_operand_shapes = operand_shapes_with_layout;
4476 }
4477 return builder->CustomCall(call_target_name, operands, shape, opaque,
4478 maybe_operand_shapes, has_side_effect,
4479 output_operand_aliasing, literal, window, dnums,
4480 schedule, api_version);
4481 }
4482
OptimizationBarrier(XlaOp operand)4483 XlaOp OptimizationBarrier(XlaOp operand) {
4484 return operand.builder()->OptimizationBarrier(operand);
4485 }
4486
Complex(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4487 XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
4488 absl::Span<const int64_t> broadcast_dimensions) {
4489 return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
4490 broadcast_dimensions);
4491 }
4492
Conj(const XlaOp operand)4493 XlaOp Conj(const XlaOp operand) {
4494 return Complex(Real(operand), Neg(Imag(operand)));
4495 }
4496
Add(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4497 XlaOp Add(const XlaOp lhs, const XlaOp rhs,
4498 absl::Span<const int64_t> broadcast_dimensions) {
4499 return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
4500 broadcast_dimensions);
4501 }
4502
Sub(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4503 XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
4504 absl::Span<const int64_t> broadcast_dimensions) {
4505 return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
4506 broadcast_dimensions);
4507 }
4508
Mul(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4509 XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
4510 absl::Span<const int64_t> broadcast_dimensions) {
4511 return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
4512 broadcast_dimensions);
4513 }
4514
Div(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4515 XlaOp Div(const XlaOp lhs, const XlaOp rhs,
4516 absl::Span<const int64_t> broadcast_dimensions) {
4517 return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
4518 broadcast_dimensions);
4519 }
4520
Rem(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4521 XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
4522 absl::Span<const int64_t> broadcast_dimensions) {
4523 return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
4524 broadcast_dimensions);
4525 }
4526
Max(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4527 XlaOp Max(const XlaOp lhs, const XlaOp rhs,
4528 absl::Span<const int64_t> broadcast_dimensions) {
4529 return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
4530 broadcast_dimensions);
4531 }
4532
Min(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4533 XlaOp Min(const XlaOp lhs, const XlaOp rhs,
4534 absl::Span<const int64_t> broadcast_dimensions) {
4535 return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
4536 broadcast_dimensions);
4537 }
4538
And(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4539 XlaOp And(const XlaOp lhs, const XlaOp rhs,
4540 absl::Span<const int64_t> broadcast_dimensions) {
4541 return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
4542 broadcast_dimensions);
4543 }
4544
Or(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4545 XlaOp Or(const XlaOp lhs, const XlaOp rhs,
4546 absl::Span<const int64_t> broadcast_dimensions) {
4547 return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
4548 broadcast_dimensions);
4549 }
4550
Xor(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4551 XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
4552 absl::Span<const int64_t> broadcast_dimensions) {
4553 return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
4554 broadcast_dimensions);
4555 }
4556
Not(const XlaOp operand)4557 XlaOp Not(const XlaOp operand) {
4558 return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
4559 }
4560
PopulationCount(const XlaOp operand)4561 XlaOp PopulationCount(const XlaOp operand) {
4562 return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
4563 }
4564
ShiftLeft(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4565 XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
4566 absl::Span<const int64_t> broadcast_dimensions) {
4567 return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
4568 broadcast_dimensions);
4569 }
4570
ShiftRightArithmetic(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4571 XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
4572 absl::Span<const int64_t> broadcast_dimensions) {
4573 return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
4574 broadcast_dimensions);
4575 }
4576
ShiftRightLogical(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> broadcast_dimensions)4577 XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
4578 absl::Span<const int64_t> broadcast_dimensions) {
4579 return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
4580 broadcast_dimensions);
4581 }
4582
Reduce(const XlaOp operand,const XlaOp init_value,const XlaComputation & computation,absl::Span<const int64_t> dimensions_to_reduce)4583 XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
4584 const XlaComputation& computation,
4585 absl::Span<const int64_t> dimensions_to_reduce) {
4586 return operand.builder()->Reduce(operand, init_value, computation,
4587 dimensions_to_reduce);
4588 }
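// Sketch of a typical reduction, using a sub-builder to define the scalar
// reduction computation (illustrative only; error handling elided):
//
//   XlaBuilder b("reduce_example");
//   Shape scalar = ShapeUtil::MakeShape(F32, {});
//   auto sub = b.CreateSubBuilder("add");
//   Add(Parameter(sub.get(), 0, scalar, "lhs"),
//       Parameter(sub.get(), 1, scalar, "rhs"));
//   XlaComputation add_computation = sub->BuildAndNoteError();
//   XlaOp input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "input");
//   XlaOp zero = ConstantLiteral(&b, LiteralUtil::CreateR0<float>(0.0f));
//   Reduce(input, zero, add_computation, /*dimensions_to_reduce=*/{1});
//   // Result shape: f32[2].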
4589
4590 // Reduces several arrays simultaneously along the provided dimensions, given
4591 // "computation" as a reduction operator.
Reduce(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,const XlaComputation & computation,absl::Span<const int64_t> dimensions_to_reduce)4592 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
4593 absl::Span<const XlaOp> init_values,
4594 const XlaComputation& computation,
4595 absl::Span<const int64_t> dimensions_to_reduce) {
4596 return builder->Reduce(operands, init_values, computation,
4597 dimensions_to_reduce);
4598 }
4599
ReduceAll(const XlaOp operand,const XlaOp init_value,const XlaComputation & computation)4600 XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
4601 const XlaComputation& computation) {
4602 return operand.builder()->ReduceAll(operand, init_value, computation);
4603 }
4604
ReduceWindow(const XlaOp operand,const XlaOp init_value,const XlaComputation & computation,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)4605 XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
4606 const XlaComputation& computation,
4607 absl::Span<const int64_t> window_dimensions,
4608 absl::Span<const int64_t> window_strides, Padding padding) {
4609 return operand.builder()->ReduceWindow(operand, init_value, computation,
4610 window_dimensions, window_strides,
4611 padding);
4612 }
4613
ReduceWindow(absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,const XlaComputation & computation,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)4614 XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
4615 absl::Span<const XlaOp> init_values,
4616 const XlaComputation& computation,
4617 absl::Span<const int64_t> window_dimensions,
4618 absl::Span<const int64_t> window_strides, Padding padding) {
4619 CHECK(!operands.empty());
4620 return operands[0].builder()->ReduceWindow(operands, init_values, computation,
4621 window_dimensions, window_strides,
4622 padding);
4623 }
4624
ReduceWindowWithGeneralPadding(const XlaOp operand,const XlaOp init_value,const XlaComputation & computation,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,absl::Span<const int64_t> base_dilations,absl::Span<const int64_t> window_dilations,absl::Span<const std::pair<int64_t,int64_t>> padding)4625 XlaOp ReduceWindowWithGeneralPadding(
4626 const XlaOp operand, const XlaOp init_value,
4627 const XlaComputation& computation,
4628 absl::Span<const int64_t> window_dimensions,
4629 absl::Span<const int64_t> window_strides,
4630 absl::Span<const int64_t> base_dilations,
4631 absl::Span<const int64_t> window_dilations,
4632 absl::Span<const std::pair<int64_t, int64_t>> padding) {
4633 return operand.builder()->ReduceWindowWithGeneralPadding(
4634 absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
4635 window_dimensions, window_strides, base_dilations, window_dilations,
4636 padding);
4637 }
4638
ReduceWindowWithGeneralPadding(absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,const XlaComputation & computation,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,absl::Span<const int64_t> base_dilations,absl::Span<const int64_t> window_dilations,absl::Span<const std::pair<int64_t,int64_t>> padding)4639 XlaOp ReduceWindowWithGeneralPadding(
4640 absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
4641 const XlaComputation& computation,
4642 absl::Span<const int64_t> window_dimensions,
4643 absl::Span<const int64_t> window_strides,
4644 absl::Span<const int64_t> base_dilations,
4645 absl::Span<const int64_t> window_dilations,
4646 absl::Span<const std::pair<int64_t, int64_t>> padding) {
4647 CHECK(!operands.empty());
4648 return operands[0].builder()->ReduceWindowWithGeneralPadding(
4649 operands, init_values, computation, window_dimensions, window_strides,
4650 base_dilations, window_dilations, padding);
4651 }
4652
AllGather(const XlaOp operand,int64_t all_gather_dimension,int64_t shard_count,absl::Span<const ReplicaGroup> replica_groups,const std::optional<ChannelHandle> & channel_id,const std::optional<Layout> & layout,const std::optional<bool> use_global_device_ids)4653 XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
4654 int64_t shard_count,
4655 absl::Span<const ReplicaGroup> replica_groups,
4656 const std::optional<ChannelHandle>& channel_id,
4657 const std::optional<Layout>& layout,
4658 const std::optional<bool> use_global_device_ids) {
4659 return operand.builder()->AllGather(operand, all_gather_dimension,
4660 shard_count, replica_groups, channel_id,
4661 layout, use_global_device_ids);
4662 }
4663
CrossReplicaSum(const XlaOp operand,absl::Span<const ReplicaGroup> replica_groups)4664 XlaOp CrossReplicaSum(const XlaOp operand,
4665 absl::Span<const ReplicaGroup> replica_groups) {
4666 return operand.builder()->CrossReplicaSum(operand, replica_groups);
4667 }
4668
AllReduce(const XlaOp operand,const XlaComputation & computation,absl::Span<const ReplicaGroup> replica_groups,const std::optional<ChannelHandle> & channel_id,const std::optional<Shape> & shape_with_layout,const std::optional<bool> use_global_device_ids)4669 XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
4670 absl::Span<const ReplicaGroup> replica_groups,
4671 const std::optional<ChannelHandle>& channel_id,
4672 const std::optional<Shape>& shape_with_layout,
4673 const std::optional<bool> use_global_device_ids) {
4674 return operand.builder()->AllReduce(operand, computation, replica_groups,
4675 channel_id, shape_with_layout,
4676 use_global_device_ids);
4677 }
4678
ReduceScatter(const XlaOp operand,const XlaComputation & computation,int64_t scatter_dimension,int64_t shard_count,absl::Span<const ReplicaGroup> replica_groups,const std::optional<ChannelHandle> & channel_id,const std::optional<Layout> & layout,const std::optional<bool> use_global_device_ids)4679 XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
4680 int64_t scatter_dimension, int64_t shard_count,
4681 absl::Span<const ReplicaGroup> replica_groups,
4682 const std::optional<ChannelHandle>& channel_id,
4683 const std::optional<Layout>& layout,
4684 const std::optional<bool> use_global_device_ids) {
4685 return operand.builder()->ReduceScatter(
4686 operand, computation, scatter_dimension, shard_count, replica_groups,
4687 channel_id, layout, use_global_device_ids);
4688 }
4689
AllToAll(const XlaOp operand,int64_t split_dimension,int64_t concat_dimension,int64_t split_count,absl::Span<const ReplicaGroup> replica_groups,const std::optional<Layout> & layout)4690 XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
4691 int64_t concat_dimension, int64_t split_count,
4692 absl::Span<const ReplicaGroup> replica_groups,
4693 const std::optional<Layout>& layout) {
4694 return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
4695 split_count, replica_groups, layout);
4696 }
4697
AllToAllTuple(absl::Span<const XlaOp> operands,absl::Span<const ReplicaGroup> replica_groups,const std::optional<Layout> & layout)4698 XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
4699 absl::Span<const ReplicaGroup> replica_groups,
4700 const std::optional<Layout>& layout) {
4701 CHECK(!operands.empty());
4702 return operands[0].builder()->AllToAllTuple(operands, replica_groups, layout);
4703 }
4704
AllToAllTuple(const XlaOp operand,int64_t split_dimension,int64_t concat_dimension,int64_t split_count,absl::Span<const ReplicaGroup> replica_groups,const std::optional<Layout> & layout)4705 XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension,
4706 int64_t concat_dimension, int64_t split_count,
4707 absl::Span<const ReplicaGroup> replica_groups,
4708 const std::optional<Layout>& layout) {
4709 return operand.builder()->AllToAllTuple(operand, split_dimension,
4710 concat_dimension, split_count,
4711 replica_groups, layout);
4712 }
4713
CollectivePermute(const XlaOp operand,const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs)4714 XlaOp CollectivePermute(
4715 const XlaOp operand,
4716 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
4717 return operand.builder()->CollectivePermute(operand, source_target_pairs);
4718 }
4719
ReplicaId(XlaBuilder * builder)4720 XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }
4721
SelectAndScatter(const XlaOp operand,const XlaComputation & select,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding,const XlaOp source,const XlaOp init_value,const XlaComputation & scatter)4722 XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
4723 absl::Span<const int64_t> window_dimensions,
4724 absl::Span<const int64_t> window_strides,
4725 Padding padding, const XlaOp source,
4726 const XlaOp init_value, const XlaComputation& scatter) {
4727 return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
4728 window_strides, padding, source,
4729 init_value, scatter);
4730 }
4731
SelectAndScatterWithGeneralPadding(const XlaOp operand,const XlaComputation & select,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,const XlaOp source,const XlaOp init_value,const XlaComputation & scatter)4732 XlaOp SelectAndScatterWithGeneralPadding(
4733 const XlaOp operand, const XlaComputation& select,
4734 absl::Span<const int64_t> window_dimensions,
4735 absl::Span<const int64_t> window_strides,
4736 absl::Span<const std::pair<int64_t, int64_t>> padding, const XlaOp source,
4737 const XlaOp init_value, const XlaComputation& scatter) {
4738 return operand.builder()->SelectAndScatterWithGeneralPadding(
4739 operand, select, window_dimensions, window_strides, padding, source,
4740 init_value, scatter);
4741 }
4742
XlaOp Abs(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Exp(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
}
XlaOp Expm1(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
}
XlaOp Floor(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
}
XlaOp Ceil(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
}
XlaOp Round(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
XlaOp RoundNearestEven(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestEven, operand);
}
XlaOp Log(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
}
XlaOp Log1p(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
}
XlaOp Logistic(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand);
}
XlaOp Sign(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
}
XlaOp Clz(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
}
XlaOp Cos(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
}
XlaOp Sin(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
}
XlaOp Tanh(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
}
XlaOp Real(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
}
XlaOp Imag(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
}
XlaOp Sqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
}
XlaOp Cbrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand);
}
XlaOp Rsqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
}

XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp IsFinite(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}

XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->ConvertElementType(operand, new_element_type);
}

XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->BitcastConvertType(operand, new_element_type);
}

XlaOp Neg(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
}

XlaOp Transpose(const XlaOp operand, absl::Span<const int64_t> permutation) {
  return operand.builder()->Transpose(operand, permutation);
}

XlaOp Rev(const XlaOp operand, absl::Span<const int64_t> dimensions) {
  return operand.builder()->Rev(operand, dimensions);
}

XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension, bool is_stable) {
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}

XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
  return min.builder()->Clamp(min, operand, max);
}

XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation,
          absl::Span<const int64_t> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions, static_operands);
}

XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  return initial_state.builder()->RngBitGenerator(algorithm, initial_state,
                                                  shape);
}

XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  return init.builder()->While(condition, body, init);
}

XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}

XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
                      const int mantissa_bits) {
  return operand.builder()->ReducePrecision(operand, exponent_bits,
                                            mantissa_bits);
}

XlaOp Gather(const XlaOp input, const XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
  return input.builder()->Gather(input, start_indices, dimension_numbers,
                                 slice_sizes, indices_are_sorted);
}

XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
              const XlaOp updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return input.builder()->Scatter(input, scatter_indices, updates,
                                  update_computation, dimension_numbers,
                                  indices_are_sorted, unique_indices);
}

XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
              absl::Span<const XlaOp> updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return scatter_indices.builder()->Scatter(
      inputs, scatter_indices, updates, update_computation, dimension_numbers,
      indices_are_sorted, unique_indices);
}

void Send(const XlaOp operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const std::string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const std::string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}

XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64_t feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64_t feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64_t feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
                                          grad_output, epsilon, feature_index);
}

XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}

XlaOp GetDimensionSize(const XlaOp operand, int64_t dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val,
                       int64_t dimension) {
  return operand.builder()->SetDimensionSize(operand, val, dimension);
}

XlaOp RemoveDynamicDimension(const XlaOp operand, int64_t dimension) {
  return operand.builder()->RemoveDynamicDimension(operand, dimension);
}

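// Builds the sharding used to annotate the shard-shaped side of the
// SPMDFullToShardShape/SPMDShardToFullShape custom calls below. If
// `single_dim` is negative or `original` is not a tiled sharding, the result
// is fully MANUAL. Otherwise only `single_dim` is switched to manual
// partitioning: its partition count is moved to an appended last tile
// dimension marked MANUAL, and the device assignment is remapped from
// `original` accordingly. For example, an `original` with
// tile_assignment_dimensions [2, 4] and single_dim == 1 yields tile
// dimensions [2, 1, 4] with the appended dimension marked MANUAL.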
OpSharding GetManualSharding(const OpSharding& original, int64_t single_dim) {
  OpSharding manual;
  if (single_dim < 0 || original.type() != OpSharding::OTHER) {
    manual.set_type(OpSharding::MANUAL);
    return manual;
  }
  manual.set_type(OpSharding::OTHER);
  std::vector<int64_t> new_tile_shape(
      original.tile_assignment_dimensions().begin(),
      original.tile_assignment_dimensions().end());
  new_tile_shape.push_back(new_tile_shape[single_dim]);
  new_tile_shape[single_dim] = 1;
  Array<int64_t> new_tile(new_tile_shape);
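  // Map each coordinate of the new tile assignment to its device in the
  // original assignment; the coordinate along `single_dim` is taken from the
  // appended last dimension.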
  new_tile.Each([&](absl::Span<const int64_t> indices, int64_t* v) {
    int64_t src_index = 0;
    for (int64_t i = 0; i < indices.size() - 1; ++i) {
      if (i > 0) {
        src_index *= new_tile_shape[i];
      }
      int64_t index = indices[i];
      if (i == single_dim) {
        index = indices.back();
      }
      src_index += index;
    }
    *v = original.tile_assignment_devices(src_index);
  });
  for (int64_t dim : new_tile_shape) {
    manual.add_tile_assignment_dimensions(dim);
  }
  for (int64_t device : new_tile) {
    manual.add_tile_assignment_devices(device);
  }
  if (original.replicate_on_last_tile_dim()) {
    manual.add_last_tile_dims(OpSharding::REPLICATED);
  }
  for (int64_t type : original.last_tile_dims()) {
    manual.add_last_tile_dims(static_cast<OpSharding::Type>(type));
  }
  manual.add_last_tile_dims(OpSharding::MANUAL);
  return manual;
}

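// Annotates `input` (which has the full, unpartitioned shape) with
// `manual_sharding` and then transitions it to the corresponding per-device
// shard shape via the SPMDFullToShardShape custom call, so that subsequent
// ops can be written against manually partitioned data. For example
// (hypothetical values), with a full shape f32[8,128] and a `manual_sharding`
// that tiles dimension 0 across 4 devices, the returned op has shape
// f32[2,128].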
StatusOr<XlaOp> ConvertSpmdFullToShardShape(
    XlaBuilder* builder, XlaOp input, int single_dim,
    const OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims) {
  TF_ASSIGN_OR_RETURN(const Shape input_shape, builder->GetShape(input));

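  // Derive the shard shape: every dimension tiled by `manual_sharding`
  // shrinks to ceil(full_size / partitions); when `single_dim` is
  // non-negative, only that dimension is adjusted.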
  Shape output_shape = input_shape;
  const int64_t rank = output_shape.rank();
  if (manual_sharding.type() == OpSharding::OTHER) {
    for (int64_t i = 0; i < rank; ++i) {
      if (single_dim >= 0 && i != single_dim) {
        continue;
      }
      const int64_t partitions_i =
          manual_sharding.tile_assignment_dimensions(i);
      if (partitions_i == 1) continue;
      const int64_t dim_size =
          CeilOfRatio(output_shape.dimensions(i), partitions_i);
      output_shape.set_dimensions(i, dim_size);
    }
  }

  XlaOp input_annotation;
  {
    // Annotate the full-shape input with the sharding.
    XlaScopedShardingAssignment assign_sharding(builder, manual_sharding);
    input_annotation = CustomCall(
        builder, /*call_target_name=*/"Sharding", {input}, input_shape,
        /*opaque=*/
        sharding_op_util::EncodeAttributes(unspecified_dims));
  }

  {
    // Annotate the shard-shape output with manual sharding, so that the
    // partitioner will leave it as is.
    OpSharding manual = GetManualSharding(manual_sharding, single_dim);
    XlaScopedShardingAssignment assign_sharding(builder, manual);
    return CustomCall(builder,
                      /*call_target_name=*/"SPMDFullToShardShape",
                      {input_annotation}, output_shape,
                      /*opaque=*/
                      sharding_op_util::EncodeAttributes(unspecified_dims));
  }
}

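// The inverse of ConvertSpmdFullToShardShape: annotates the shard-shaped
// `input` as manually partitioned and transitions it back to the full
// `output_shape` via the SPMDShardToFullShape custom call, which is annotated
// with `manual_sharding` for the partitioner.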
StatusOr<XlaOp> ConvertSpmdShardToFullShape(
    XlaBuilder* builder, XlaOp input, const Shape& output_shape, int single_dim,
    const OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims) {
  TF_ASSIGN_OR_RETURN(const Shape input_shape, builder->GetShape(input));

  XlaOp input_annotation;
  {
    // Annotate the shard-shape input with manual sharding, so that the
    // partitioner will leave it as is.
    OpSharding manual = GetManualSharding(manual_sharding, single_dim);
    XlaScopedShardingAssignment assign_sharding(builder, manual);
    input_annotation = CustomCall(
        builder, /*call_target_name=*/"Sharding", {input}, input_shape,
        sharding_op_util::EncodeAttributes(unspecified_dims));
  }

  {
    // Annotate the full-shape output with the sharding.
    XlaScopedShardingAssignment assign_sharding(builder, manual_sharding);
    return CustomCall(builder,
                      /*call_target_name=*/"SPMDShardToFullShape",
                      {input_annotation}, output_shape,
                      sharding_op_util::EncodeAttributes(unspecified_dims));
  }
}

}  // namespace xla