/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/xla_builder.h"

#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/sharding_op_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

using absl::StrCat;

namespace {

static const char kNameSeparator = '.';

// Retrieves the base name of an instruction's or computation's fully
// qualified name, using `separator` as the boundary between the base name
// and the numeric identifier.
std::string GetBaseName(const std::string& name, char separator) {
  auto pos = name.rfind(separator);
  CHECK_NE(pos, std::string::npos) << name;
  return name.substr(0, pos);
}
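
// e.g., GetBaseName("add.5", kNameSeparator) returns "add".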

// Generates a fully qualified computation/instruction name.
std::string GetFullName(const std::string& base_name, char separator,
                        int64_t id) {
  const char separator_str[] = {separator, '\0'};
  return StrCat(base_name, separator_str, id);
}
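
// e.g., GetFullName("add", kNameSeparator, 5) returns "add.5".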

// Common function to standardize setting name and IDs on computation and
// instruction proto entities.
template <typename T>
void SetProtoIdAndName(T* entry, const std::string& base_name, char separator,
                       int64_t id) {
  entry->set_id(id);
  entry->set_name(GetFullName(base_name, separator, id));
}

bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
  HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
  if (opcode == HloOpcode::kCustomCall &&
      instr_proto->custom_call_target() == "SetBound") {
    return true;
  }
  return false;
}

}  // namespace

namespace internal {

XlaOp XlaBuilderFriend::BuildAddDependency(XlaBuilder* builder, XlaOp operand,
                                           XlaOp token, const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kAddDependency,
                                   {operand, token});
  });
}

XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
                                    absl::Span<const XlaOp> operands,
                                    absl::string_view fusion_kind,
                                    const XlaComputation& fused_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_fusion_kind(std::string(fusion_kind));
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(auto program_shape,
                        fused_computation.GetProgramShape());
    *instr.mutable_shape() = program_shape.result().ToProto();
    builder->AddCalledComputation(fused_computation, &instr);
    return builder->AddInstruction(std::move(instr), HloOpcode::kFusion,
                                   operands);
  });
}

XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
                                     const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
                                   {operand});
  });
}

XlaOp XlaBuilderFriend::BuildDomain(XlaBuilder* builder, XlaOp operand,
                                    const OpSharding entry,
                                    const OpSharding exit, const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_domain_entry_sharding() = entry;
    *instr.mutable_domain_exit_sharding() = exit;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kDomain,
                                   {operand});
  });
}

XlaOp XlaBuilderFriend::BuildPartitionId(XlaBuilder* builder,
                                         const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kPartitionId);
  });
}

XlaOp XlaBuilderFriend::BuildRngGetAndUpdateState(XlaBuilder* builder,
                                                  int64_t delta,
                                                  const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_delta(delta);
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr),
                                   HloOpcode::kRngGetAndUpdateState);
  });
}

HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
  return &op.builder()
              ->instructions_[op.builder()->handle_to_index_[op.handle_]];
}

HloInstructionProto* XlaBuilderFriend::GetInstructionByHandle(
    XlaBuilder* builder, int64_t handle) {
  return &builder->instructions_[builder->handle_to_index_[handle]];
}

}  // namespace internal

XlaOp operator-(XlaOp x) { return Neg(x); }
XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }

XlaOp operator~(XlaOp x) { return Not(x); }
XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }

XlaOp operator>>(XlaOp x, XlaOp y) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(x));
    if (!ShapeUtil::ElementIsIntegral(*shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    if (ShapeUtil::ElementIsSigned(*shape)) {
      return ShiftRightArithmetic(x, y);
    } else {
      return ShiftRightLogical(x, y);
    }
  });
}
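
// A minimal usage sketch (illustrative only; Parameter and ShapeUtil are the
// usual XLA client helpers):
//
//   XlaBuilder b("shifts");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4}), "x");
//   XlaOp y = x >> (x + x);  // S32 is signed, so this lowers to
//                            // ShiftRightArithmetic.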

StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  auto it = handle_to_index_.find(op.handle());
  if (it == handle_to_index_.end()) {
    return InvalidArgument("No XlaOp with handle %d", op.handle());
  }
  return instruction_shapes_.at(it->second).get();
}

StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
  TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
  return *shape;
}

StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    absl::Span<const XlaOp> operands) const {
  std::vector<Shape> operand_shapes;
  operand_shapes.reserve(operands.size());
  for (XlaOp operand : operands) {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    operand_shapes.push_back(*shape);
  }
  return operand_shapes;
}

std::string XlaBuilder::OpToString(XlaOp op) const {
  std::string s;
  ToStringHelper(&s, /*ident=*/0, op.handle());
  return s;
}

static std::string ShapeToString(const ShapeProto& shape) {
  if (shape.tuple_shapes_size() > 1) {
    return absl::StrCat(
        "(",
        absl::StrJoin(shape.tuple_shapes(), ", ",
                      [&](std::string* s, const ShapeProto& subshape) {
                        absl::StrAppend(s, ShapeToString(subshape));
                      }),
        ")");
  }
  return absl::StrCat("[", absl::StrJoin(shape.dimensions(), ", "), "]");
}
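
// e.g., a two-element tuple shape (f32[2,3], s32[4]) renders as
// "([2, 3], [4])"; an array shape renders as just its dimensions, so a scalar
// renders as "[]".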

void XlaBuilder::ToStringHelper(std::string* out, int ident,
                                int64_t op_handle) const {
  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  absl::StrAppend(out, std::string(ident, ' '), instr.opcode(),
                  ", shape=", ShapeToString(instr.shape()));
  if (instr.has_metadata()) {
    absl::StrAppend(out, ", metadata={", instr.metadata().source_file(), ":",
                    instr.metadata().source_line(), "}");
  }
  if (instr.operand_ids_size()) {
    absl::StrAppend(out, "\n");
  }
  absl::StrAppend(out, absl::StrJoin(instr.operand_ids(), "\n",
                                     [&](std::string* s, int64_t subop) {
                                       ToStringHelper(s, ident + 2, subop);
                                     }));
}

XlaBuilder::XlaBuilder(const std::string& computation_name)
    : name_(computation_name) {}

XlaBuilder::~XlaBuilder() {}

XlaOp XlaBuilder::ReportError(const Status& error) {
  CHECK(!error.ok());
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
  return XlaOp(this);
}

XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
  if (!first_error_.ok()) {
    return XlaOp(this);
  }
  if (!op.ok()) {
    return ReportError(op.status());
  }
  return op.ValueOrDie();
}

XlaOp XlaBuilder::ReportErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  return ReportErrorOrReturn(op_creator());
}
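
// Note: nearly every op-building method below wraps its body in
// ReportErrorOrReturn. Once an error is recorded in first_error_, subsequent
// calls become no-ops returning a sentinel XlaOp, and the first error is what
// GetCurrentStatus() and Build() eventually surface.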

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64_t root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
                      LookUpInstructionByHandle(root_id));

  ProgramShape program_shape;

  *program_shape.mutable_result() = Shape(root_proto->shape());

  // Check that the parameter numbers are contiguous from 0, and add parameter
  // shapes and names to the program shape.
  const int64_t param_count = parameter_numbers_.size();
  for (int64_t i = 0; i < param_count; i++) {
    program_shape.add_parameters();
    program_shape.add_parameter_names();
  }
  for (const HloInstructionProto& instr : instructions_) {
    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
    // to verify contiguity, we just need to verify that every parameter is in
    // the right range.
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      const int64_t index = instr.parameter_number();
      TF_RET_CHECK(index >= 0 && index < param_count)
          << "invalid parameter number: " << index;
      *program_shape.mutable_parameters(index) = Shape(instr.shape());
      *program_shape.mutable_parameter_names(index) = instr.name();
    }
  }
  return program_shape;
}
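
// e.g., for a computation with parameters p0: f32[2] and p1: s32[] whose root
// produces f32[2], the resulting program shape is (f32[2], s32[]) -> f32[2].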

StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  TF_RET_CHECK(!instructions_.empty());
  return GetProgramShape(instructions_.back().id());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return GetProgramShape(root.handle());
}

void XlaBuilder::IsConstantVisitor(const int64_t op_handle, int depth,
                                   absl::flat_hash_set<int64_t>* visited,
                                   bool* is_constant) const {
  if (visited->contains(op_handle) || !*is_constant) {
    return;
  }

  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  HloInstructionProto to_print(instr);
  to_print.clear_shape();
  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
  const std::string indent =
      absl::StrJoin(std::vector<absl::string_view>(depth, "  "), "");
  if (VLOG_IS_ON(2)) {
    VLOG(2) << indent << "Visiting:";
    for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
      VLOG(2) << indent << l;
    }
  }
  switch (opcode) {
    default:
      for (const int64_t operand_id : instr.operand_ids()) {
        IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
      }
      // TODO(b/32495713): We aren't checking the called computations.
      break;

    case HloOpcode::kGetDimensionSize:
      // GetDimensionSize is always considered constant in XLA -- if a dynamic
      // dimension is present, -1 is returned.
      break;
    // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant.  We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
    case HloOpcode::kCustomCall:
      if (instr.custom_call_target() == "SetBound") {
        // SetBound is considered constant -- the bound is used as the value.
        break;
      }
      [[fallthrough]];
    case HloOpcode::kWhile:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kScatter:
      // TODO(b/32495713): We aren't checking the embedded computation in
      // Scatter.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter:
      *is_constant = false;
      break;
    case HloOpcode::kGetTupleElement: {
      const HloInstructionProto& operand_instr =
          *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
      if (HloOpcodeString(HloOpcode::kTuple) == operand_instr.opcode()) {
        IsConstantVisitor(operand_instr.operand_ids(instr.tuple_index()),
                          depth + 1, visited, is_constant);
      } else {
        for (const int64_t operand_id : instr.operand_ids()) {
          IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
        }
      }
    }
  }
  if (VLOG_IS_ON(1) && !*is_constant) {
    VLOG(1) << indent << "Non-constant: ";
    for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
      VLOG(1) << indent << l;
    }
  }
  visited->insert(op_handle);
}
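
// e.g., an expression tree rooted at Add(Constant, Constant) is reported as
// constant, while any path that reaches a kParameter (or another op in the
// non-constant list above) flips *is_constant to false for the whole query.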

Status XlaBuilder::SetDynamicBinding(int64_t dynamic_size_param_num,
                                     ShapeIndex dynamic_size_param_index,
                                     int64_t target_param_num,
                                     ShapeIndex target_param_index,
                                     int64_t target_dim_num) {
  bool param_exists = false;
  for (size_t index = 0; index < instructions_.size(); ++index) {
    HloInstructionProto& instr = instructions_[index];
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
        instr.parameter_number() == target_param_num) {
      param_exists = true;
      Shape param_shape(instr.shape());
      Shape* param_shape_ptr = &param_shape;
      for (int64_t index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      *instr.mutable_shape() = param_shape.ToProto();
      instruction_shapes_[index] =
          std::make_unique<Shape>(std::move(param_shape));
    }
  }
  if (!param_exists) {
    return InvalidArgument(
        "Asked to mark parameter %lld as dynamic sized parameter, but the "
        "parameter doesn't exist",
        target_param_num);
  }

  TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
      DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
                                                dynamic_size_param_index},
      DynamicParameterBinding::DynamicDimension{
          target_param_num, target_param_index, target_dim_num}));
  return OkStatus();
}
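
// e.g., SetDynamicBinding(/*dynamic_size_param_num=*/0, {},
// /*target_param_num=*/1, {}, /*target_dim_num=*/0) declares that scalar
// parameter 0 holds the runtime size of dimension 0 of parameter 1, and marks
// that dimension as dynamic in parameter 1's shape.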

Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
                                                   std::string attribute,
                                                   std::string value) {
  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
  (*frontend_attributes->mutable_map())[attribute] = std::move(value);
  return OkStatus();
}

XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->ReportError(
        AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
    return {};
  }
  return std::move(build_status).value();
}

Status XlaBuilder::GetCurrentStatus() const {
  if (!first_error_.ok()) {
    std::string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }
  return OkStatus();
}

StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());
  return Build(instructions_.back().id(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
                                           bool remove_dynamic_dimensions) {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return Build(root.handle(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(int64_t root_id,
                                           bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());

  // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet;
  // remove all dynamic dimensions before building the XLA program, until the
  // backend gains support for them.
  if (remove_dynamic_dimensions) {
    std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
      if (shape->tuple_shapes_size() != 0) {
        for (int i = 0; i < shape->tuple_shapes_size(); ++i) {
          remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
        }
      }
      for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
        shape->set_dynamic_dimension(i, false);
      }
    };
    for (size_t index = 0; index < instructions_.size(); ++index) {
      remove_dynamic_dimension(instruction_shapes_[index].get());
      *instructions_[index].mutable_shape() =
          instruction_shapes_[index]->ToProto();
    }
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
  *entry.mutable_program_shape() = program_shape.ToProto();
  entry.set_root_id(root_id);

  for (auto& instruction : instructions_) {
    // Ensures that the instruction names are unique across the whole graph.
    instruction.set_name(
        GetFullName(instruction.name(), kNameSeparator, instruction.id()));
    entry.add_instructions()->Swap(&instruction);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = entry.program_shape();
  for (auto& e : embedded_) {
    module->add_computations()->Swap(&e.second);
  }
  module->add_computations()->Swap(&entry);
  if (!input_output_aliases_.empty()) {
    TF_RETURN_IF_ERROR(
        PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
  }
  *(module->mutable_dynamic_parameter_binding()) =
      dynamic_parameter_binding_.ToProto();

  // Clear data held by this builder.
  this->instructions_.clear();
  this->instruction_shapes_.clear();
  this->handle_to_index_.clear();
  this->embedded_.clear();
  this->parameter_numbers_.clear();

  return std::move(computation);
}
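
// A minimal end-to-end sketch (illustrative only):
//
//   XlaBuilder b("add_one");
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
//   Add(p, ConstantR0<float>(&b, 1.0f));
//   StatusOr<XlaComputation> computation = b.Build();  // Roots at the last
//                                                      // added instruction.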

/* static */ Status XlaBuilder::PopulateInputOutputAlias(
    HloModuleProto* module, const ProgramShape& program_shape,
    const std::vector<InputOutputAlias>& input_output_aliases) {
  HloInputOutputAliasConfig config(program_shape.result());
  for (auto& alias : input_output_aliases) {
    // The HloInputOutputAliasConfig does not do parameter validation, as it
    // only carries the result shape. Maybe it should be constructed with a
    // ProgramShape to allow full validation. We will still get an error when
    // trying to compile the HLO module, but it would be better to have
    // validation at this stage.
    if (alias.param_number >= program_shape.parameters_size()) {
      return InvalidArgument("Invalid parameter number %ld (total %ld)",
                             alias.param_number,
                             program_shape.parameters_size());
    }
    const Shape& parameter_shape = program_shape.parameters(alias.param_number);
    if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
      return InvalidArgument("Invalid parameter %ld index: %s",
                             alias.param_number,
                             alias.param_index.ToString().c_str());
    }
    TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
                                         alias.param_index, alias.kind));
  }
  *module->mutable_input_output_alias() = config.ToProto();
  return OkStatus();
}

StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64_t> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }

  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                 XlaOp operand) {
  TF_RETURN_IF_ERROR(first_error_);

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

  CHECK(ShapeUtil::IsScalar(*operand_shape) ||
        operand_shape->rank() == output_shape.rank());
  Shape broadcast_shape =
      ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());

  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(*operand_shape)) {
    return InDimBroadcast(broadcast_shape, operand, {});
  }

  // Do explicit broadcast for degenerate broadcast.
  std::vector<int64_t> broadcast_dimensions;
  std::vector<int64_t> reshaped_dimensions;
  for (int i = 0; i < operand_shape->rank(); i++) {
    if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_shape->dimensions(i));
    } else {
      TF_RET_CHECK(operand_shape->dimensions(i) == 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand shape: "
          << *operand_shape << "; output_shape: " << output_shape;
    }
  }

  Shape reshaped_shape =
      ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);

  std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);

  for (auto& unmodified : unmodified_dims) {
    if (operand_shape->is_dynamic_dimension(unmodified.first)) {
      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
    }
  }

  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(
      XlaOp reshaped_operand,
      ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
}
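
// e.g., broadcasting f32[2,1,3] up to f32[2,4,3]: dimension 1 is degenerate,
// so the operand is first reshaped to f32[2,3] and then broadcast with
// dimensions={0, 2}, which materializes dimension 1 of the output.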

XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
    return AddOpWithShape(unop, shape, {operand});
  });
}

XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64_t> broadcast_dimensions,
                           std::optional<ComparisonDirection> direction,
                           std::optional<Comparison::Type> type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBinaryOpShape(
                         binop, *lhs_shape, *rhs_shape, broadcast_dimensions));

    const int64_t lhs_rank = lhs_shape->rank();
    const int64_t rhs_rank = rhs_shape->rank();

    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
      const bool should_broadcast_lhs = lhs_rank < rhs_rank;
      XlaOp from = should_broadcast_lhs ? lhs : rhs;
      const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;

      std::vector<int64_t> to_size;
      std::vector<bool> to_size_is_dynamic;
      const auto rank = shape.rank();
      to_size.reserve(rank);
      to_size_is_dynamic.reserve(rank);
      for (int i = 0; i < rank; i++) {
        to_size.push_back(shape.dimensions(i));
        to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
      }
      for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
        int64_t to_dim = broadcast_dimensions[from_dim];
        to_size[to_dim] = from_shape.dimensions(from_dim);
        to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
      }

      const Shape& broadcasted_shape = ShapeUtil::MakeShape(
          from_shape.element_type(), to_size, to_size_is_dynamic);
      TF_ASSIGN_OR_RETURN(
          XlaOp broadcasted_operand,
          InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));

      updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
      updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
    }

    TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
                        GetShapePtr(updated_lhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_lhs,
                          AddBroadcastSequence(shape, updated_lhs));
    }
    TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
                        GetShapePtr(updated_rhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_rhs,
                          AddBroadcastSequence(shape, updated_rhs));
    }

    if (binop == HloOpcode::kCompare) {
      if (!direction.has_value()) {
        return InvalidArgument(
            "kCompare expects a ComparisonDirection, but none provided.");
      }
      if (type == std::nullopt) {
        return Compare(shape, updated_lhs, updated_rhs, *direction);
      } else {
        return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
      }
    }

    if (direction.has_value()) {
      return InvalidArgument(
          "A comparison direction is provided for a non-compare opcode: %s.",
          HloOpcodeString(binop));
    }
    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
  });
}
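
// e.g., Add(lhs: f32[2,3], rhs: f32[3], broadcast_dimensions={1}) maps the
// rhs's only dimension onto output dimension 1, so the rhs is first broadcast
// to f32[2,3] via InDimBroadcast before the element-wise kAdd is emitted.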

XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                      XlaOp lhs, XlaOp rhs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), binop, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction) {
  TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
  return Compare(
      shape, lhs, rhs, direction,
      Comparison::DefaultComparisonType(operand_shape.element_type()));
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction,
                                    Comparison::Type type) {
  HloInstructionProto instr;
  instr.set_comparison_direction(ComparisonDirectionToString(direction));
  instr.set_comparison_type(ComparisonTypeToString(type));
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
}

XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    XlaOp updated_ehs = ehs;
    // The client API supports implicit broadcast for kSelect and kClamp, but
    // XLA does not support implicit broadcast. Make implicit broadcast explicit
    // and update the operands.
    if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
      TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
      TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
      TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));

      std::optional<Shape> non_scalar_shape;
      for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
        if (shape->IsArray() && shape->rank() != 0) {
          if (non_scalar_shape.has_value()) {
            // TODO(jpienaar): The case where we need to compute the broadcasted
            // shape by considering multiple of the shapes is not implemented.
            // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
            TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
                         shape->dimensions())
                << "Unimplemented implicit broadcast.";
          } else {
            non_scalar_shape = *shape;
          }
        }
      }
      if (non_scalar_shape.has_value()) {
        if (ShapeUtil::IsScalar(*lhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_lhs,
                              AddBroadcastSequence(*non_scalar_shape, lhs));
        }
        if (ShapeUtil::IsScalar(*rhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_rhs,
                              AddBroadcastSequence(*non_scalar_shape, rhs));
        }
        if (ShapeUtil::IsScalar(*ehs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_ehs,
                              AddBroadcastSequence(*non_scalar_shape, ehs));
        }
      }
    }

    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
    TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
    StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
        triop, *lhs_shape, *rhs_shape, *ehs_shape);
    if (!status_or_shape.status().ok()) {
      return InvalidArgument(
          "%s Input scalar shapes may have been changed to non-scalar shapes.",
          status_or_shape.status().error_message());
    }

    return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
                          {updated_lhs, updated_rhs, updated_ehs});
  });
}
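
// e.g., Clamp(min: f32[], x: f32[2,3], max: f32[]) broadcasts the scalar min
// and max up to f32[2,3] before emitting kClamp, since the HLO instruction
// itself requires matching operand shapes.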

XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (literal.shape().IsArray() && literal.element_count() > 1 &&
        literal.IsAllFirst()) {
      Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
      HloInstructionProto instr;
      *instr.mutable_shape() = scalar.shape().ToProto();
      *instr.mutable_literal() = scalar.ToProto();
      TF_ASSIGN_OR_RETURN(
          XlaOp scalar_op,
          AddInstruction(std::move(instr), HloOpcode::kConstant));
      return Broadcast(scalar_op, literal.shape().dimensions());
    } else {
      HloInstructionProto instr;
      *instr.mutable_shape() = literal.shape().ToProto();
      *instr.mutable_literal() = literal.ToProto();
      return AddInstruction(std::move(instr), HloOpcode::kConstant);
    }
  });
}
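
// e.g., an f32[1024] literal whose elements are all 42.0f is emitted as a
// scalar kConstant followed by a kBroadcast rather than as a 1024-element
// literal, keeping the proto small.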

XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(iota_dimension);
    return AddInstruction(std::move(instr), HloOpcode::kIota);
  });
}

XlaOp XlaBuilder::Iota(PrimitiveType type, int64_t size) {
  return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}
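
// e.g., Iota(S32, 4) produces the rank-1 value [0, 1, 2, 3]; the Shape
// overload generalizes this to filling any single dimension of a
// higher-rank array.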

XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       absl::Span<const XlaOp> operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
                                         operand_shape_ptrs,
                                         /*to_apply=*/called_program_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
  });
}

XlaOp XlaBuilder::Parameter(
    int64_t parameter_number, const Shape& shape, const std::string& name,
    const std::vector<bool>& replicated_at_leaf_buffers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!parameter_numbers_.insert(parameter_number).second) {
      return InvalidArgument("parameter %d already registered",
                             parameter_number);
    }
    instr.set_parameter_number(parameter_number);
    instr.set_name(name);
    *instr.mutable_shape() = shape.ToProto();
    if (!replicated_at_leaf_buffers.empty()) {
      auto replication = instr.mutable_parameter_replication();
      for (bool replicated : replicated_at_leaf_buffers) {
        replication->add_replicated_at_leaf_buffers(replicated);
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kParameter);
  });
}

XlaOp XlaBuilder::Broadcast(XlaOp operand,
                            absl::Span<const int64_t> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));

    // The client-level broadcast op just appends dimensions on the left (adds
    // lowest numbered dimensions). The HLO broadcast instruction is more
    // flexible and can add new dimensions anywhere. The instruction's
    // dimensions field maps operand dimensions to dimensions in the broadcast
    // output, so to append dimensions on the left the instruction's dimensions
    // should just be the n highest dimension numbers of the output shape where
    // n is the number of input dimensions.
    const int64_t operand_rank = operand_shape->rank();
    std::vector<int64_t> dimensions(operand_rank);
    for (int i = 0; i < operand_rank; ++i) {
      dimensions[i] = i + shape.rank() - operand_rank;
    }
    return InDimBroadcast(shape, operand, dimensions);
  });
}
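
// e.g., Broadcast(x: f32[2,3], broadcast_sizes={4}) produces f32[4,2,3] with
// the instruction's dimensions field set to {1, 2}: operand dimension 0 maps
// to output dimension 1, and operand dimension 1 to output dimension 2.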

XlaOp XlaBuilder::BroadcastInDim(
    XlaOp operand, const absl::Span<const int64_t> out_dim_size,
    const absl::Span<const int64_t> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Output shape: in the case of a degenerate broadcast, out_dim_size is
    // not necessarily the same as the dimension sizes of the output shape.
    TF_ASSIGN_OR_RETURN(auto output_shape,
                        ShapeUtil::MakeValidatedShape(
                            operand_shape->element_type(), out_dim_size));
    int64_t broadcast_rank = broadcast_dimensions.size();
    if (operand_shape->rank() != broadcast_rank) {
      return InvalidArgument(
          "Size of broadcast_dimensions has to match operand's rank; operand "
          "rank: %lld, size of broadcast_dimensions %u.",
          operand_shape->rank(), broadcast_dimensions.size());
    }
    for (int i = 0; i < broadcast_rank; i++) {
      const int64_t num_dims = out_dim_size.size();
      if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > num_dims) {
        return InvalidArgument("Broadcast dimension %lld is out of bound",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(
          broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           *operand_shape, output_shape, broadcast_dimensions)
                           .status());
    std::vector<int64_t> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_rank; i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If the broadcast is not degenerate, return the broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle the degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}
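
// e.g., BroadcastInDim(x: f32[3], out_dim_size={2, 3},
// broadcast_dimensions={1}) maps the operand's only dimension onto output
// dimension 1, yielding f32[2,3] with a single non-degenerate kBroadcast.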

StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                            int64_t inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  if (inferred_dimension != -1) {
    instr.add_dimensions(inferred_dimension);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
}

XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
                        absl::Span<const int64_t> limit_indices,
                        absl::Span<const int64_t> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
                                         *operand_shape, start_indices,
                                         limit_indices, strides));
    return SliceInternal(shape, operand, start_indices, limit_indices, strides);
  });
}

StatusOr<XlaOp> XlaBuilder::SliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64_t> start_indices,
    absl::Span<const int64_t> limit_indices,
    absl::Span<const int64_t> strides) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int i = 0, end = start_indices.size(); i < end; i++) {
    auto* slice_config = instr.add_slice_dimensions();
    slice_config->set_start(start_indices[i]);
    slice_config->set_limit(limit_indices[i]);
    slice_config->set_stride(strides[i]);
  }
  return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}

XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index,
                             int64_t limit_index, int64_t stride,
                             int64_t dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64_t> starts(shape->rank(), 0);
    std::vector<int64_t> limits(shape->dimensions().begin(),
                                shape->dimensions().end());
    std::vector<int64_t> strides(shape->rank(), 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}
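
// e.g., SliceInDim(x: f32[4,2], /*start_index=*/1, /*limit_index=*/3,
// /*stride=*/1, /*dimno=*/0) expands to Slice(x, {1, 0}, {3, 2}, {1, 1}) and
// produces f32[2,2].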

XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64_t> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            *operand_shape, start_indices_shapes, slice_sizes));
    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64_t> slice_sizes) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  for (int64_t size : slice_sizes) {
    instr.add_dynamic_slice_sizes(size);
  }

  std::vector<XlaOp> operands = {operand};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
}

XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                         *operand_shape, *update_shape, start_indices_shapes));
    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  std::vector<XlaOp> operands = {operand, update};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                        operands);
}

XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                              int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
                                         operand_shape_ptrs, dimension));
    return ConcatInDimInternal(shape, operands, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  instr.add_dimensions(dimension);

  return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}
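
// e.g., concatenating x: f32[2,3] and y: f32[4,3] along dimension 0 yields
// f32[6,3]; operands must agree on every dimension other than the
// concatenation dimension.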
1146 
Pad(XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)1147 XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
1148                       const PaddingConfig& padding_config) {
1149   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1150     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1151     TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
1152                         GetShapePtr(padding_value));
1153     TF_ASSIGN_OR_RETURN(
1154         Shape shape, ShapeInference::InferPadShape(
1155                          *operand_shape, *padding_value_shape, padding_config));
1156     return PadInternal(shape, operand, padding_value, padding_config);
1157   });
1158 }
1159 
PadInDim(XlaOp operand,XlaOp padding_value,int64_t dimno,int64_t pad_lo,int64_t pad_hi)1160 XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
1161                            int64_t pad_lo, int64_t pad_hi) {
1162   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1163     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
1164     PaddingConfig padding_config = MakeNoPaddingConfig(shape->rank());
1165     auto* dims = padding_config.mutable_dimensions(dimno);
1166     dims->set_edge_padding_low(pad_lo);
1167     dims->set_edge_padding_high(pad_hi);
1168     return Pad(operand, padding_value, padding_config);
1169   });
1170 }
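
// Example (illustrative sketch, not part of the original source; names are
// hypothetical): PadInDim is shorthand for Pad() with a PaddingConfig that
// touches a single dimension.
//
//   XlaOp x = ConstantR1<float>(&b, {1.f, 2.f, 3.f});
//   XlaOp zero = ConstantR0<float>(&b, 0.f);
//   // Result: f32[6] {0, 1, 2, 3, 0, 0}.
//   XlaOp padded = PadInDim(x, zero, /*dimno=*/0, /*pad_lo=*/1, /*pad_hi=*/2);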
1171 
1172 StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
1173                                         XlaOp padding_value,
1174                                         const PaddingConfig& padding_config) {
1175   HloInstructionProto instr;
1176   *instr.mutable_shape() = shape.ToProto();
1177   *instr.mutable_padding_config() = padding_config;
1178   return AddInstruction(std::move(instr), HloOpcode::kPad,
1179                         {operand, padding_value});
1180 }
1181 
1182 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
1183                           absl::Span<const int64_t> new_sizes,
1184                           int64_t inferred_dimension) {
1185   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1186     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1187     TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
1188                                                *operand_shape, dimensions,
1189                                                new_sizes, inferred_dimension));
1190     XlaOp transposed = IsIdentityPermutation(dimensions)
1191                            ? operand
1192                            : Transpose(operand, dimensions);
1193     return ReshapeInternal(shape, transposed, inferred_dimension);
1194   });
1195 }
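
// Example (illustrative sketch, not part of the original source; `x` is a
// hypothetical f32[2,3] operand): this overload first reorders the operand's
// dimensions, then reshapes, so the call below is equivalent to
// Transpose(x, {1, 0}) followed by a reshape to f32[6].
//
//   XlaOp r = Reshape(x, /*dimensions=*/{1, 0}, /*new_sizes=*/{6});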
1196 
1197 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes,
1198                           int64_t inferred_dimension) {
1199   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1200     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
1201     std::vector<int64_t> dimensions(shape->dimensions_size());
1202     std::iota(dimensions.begin(), dimensions.end(), 0);
1203     return Reshape(operand, dimensions, new_sizes, inferred_dimension);
1204   });
1205 }
1206 
1207 XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
1208                           int64_t inferred_dimension) {
1209   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1210     return ReshapeInternal(shape, operand, inferred_dimension);
1211   });
1212 }
1213 
1214 XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
1215                                  absl::Span<const XlaOp> dim_sizes,
1216                                  absl::Span<const int64_t> new_size_bounds,
1217                                  const std::vector<bool>& dims_are_dynamic) {
1218   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1219     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1220     std::vector<const Shape*> dim_size_shape_ptrs;
1221     TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
1222                         GetOperandShapes(dim_sizes));
1223 
1224     absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
1225                       [](const Shape& shape) { return &shape; });
1226     TF_ASSIGN_OR_RETURN(const Shape shape,
1227                         ShapeInference::InferDynamicReshapeShape(
1228                             *operand_shape, dim_size_shape_ptrs,
1229                             new_size_bounds, dims_are_dynamic));
1230     TF_RETURN_IF_ERROR(first_error_);
1231     std::vector<XlaOp> operands;
1232     operands.reserve(1 + dim_sizes.size());
1233     operands.push_back(operand);
1234     for (const XlaOp& dim_size : dim_sizes) {
1235       operands.push_back(dim_size);
1236     }
1237     HloInstructionProto instr;
1238     *instr.mutable_shape() = shape.ToProto();
1239     return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
1240                           operands);
1241   });
1242 }
1243 
1244 XlaOp XlaBuilder::Collapse(XlaOp operand,
1245                            absl::Span<const int64_t> dimensions) {
1246   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1247     if (dimensions.size() <= 1) {
1248       // Nothing to collapse; return the operand directly rather than
1249       // enqueueing a trivial reshape.
1250       return operand;
1251     }
1252 
1253     // Out-of-order collapse is not supported.
1254     // Checks that the collapsed dimensions are in order and consecutive.
1255     for (absl::Span<const int64_t>::size_type i = 1; i < dimensions.size();
1256          ++i) {
1257       if (dimensions[i] - 1 != dimensions[i - 1]) {
1258         return InvalidArgument(
1259             "Collapsed dimensions are not in consecutive order.");
1260       }
1261     }
1262 
1263     // Create a new sizes vector from the old shape, replacing the collapsed
1264     // dimensions by the product of their sizes.
1265     TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));
1266 
1267     VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
1268     VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
1269 
1270     std::vector<int64_t> new_sizes;
1271     for (int i = 0; i < original_shape->rank(); ++i) {
1272       if (i <= dimensions.front() || i > dimensions.back()) {
1273         new_sizes.push_back(original_shape->dimensions(i));
1274       } else {
1275         new_sizes.back() *= original_shape->dimensions(i);
1276       }
1277     }
1278 
1279     VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";
1280 
1281     return Reshape(operand, new_sizes);
1282   });
1283 }
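
// Example (illustrative sketch, not part of the original source; `x` is a
// hypothetical f32[2,3,4] operand): collapsing consecutive dimensions
// replaces them by the product of their sizes.
//
//   // Result shape: f32[2,12].
//   XlaOp collapsed = Collapse(x, /*dimensions=*/{1, 2});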
1284 
1285 // Dummy pass-through computation returning its parameter of shape `shape`.
1286 static StatusOr<XlaComputation> PassthroughComputation(const Shape& shape) {
1287   XlaBuilder builder("dummy");
1288   XlaOp out = Parameter(&builder, 0, shape, "p");
1289   return builder.Build(out);
1290 }
1291 
1292 XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
1293   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1294     TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
1295     TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
1296     TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
1297     if (true_shape->IsTuple()) {
1298       TF_ASSIGN_OR_RETURN(XlaComputation passthrough_true,
1299                           PassthroughComputation(*true_shape));
1300       TF_ASSIGN_OR_RETURN(XlaComputation passthrough_false,
1301                           PassthroughComputation(*false_shape));
1302       return Conditional(pred, on_true, passthrough_true, on_false,
1303                          passthrough_false);
1304     }
1305     return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false);
1306   });
1307 }
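
// Example (illustrative sketch, not part of the original source; `on_true`
// and `on_false` are hypothetical operands of matching shape): for array
// shapes Select() emits a kSelect, while for tuple shapes it lowers to a
// Conditional using the pass-through computations built above.
//
//   XlaOp pred = ConstantR0<bool>(&b, true);
//   XlaOp chosen = Select(pred, on_true, on_false);  // same shape as inputs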
1308 
1309 XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
1310   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1311     std::vector<const Shape*> operand_shape_ptrs;
1312     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
1313     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1314                       [](const Shape& shape) { return &shape; });
1315     TF_ASSIGN_OR_RETURN(const Shape shape,
1316                         ShapeInference::InferVariadicOpShape(
1317                             HloOpcode::kTuple, operand_shape_ptrs));
1318     return TupleInternal(shape, elements);
1319   });
1320 }
1321 
1322 StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
1323                                           absl::Span<const XlaOp> elements) {
1324   HloInstructionProto instr;
1325   *instr.mutable_shape() = shape.ToProto();
1326   return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
1327 }
1328 
1329 XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64_t index) {
1330   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1331     TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
1332     if (!tuple_shape->IsTuple()) {
1333       return InvalidArgument(
1334           "Operand to GetTupleElement() is not a tuple; got %s",
1335           ShapeUtil::HumanString(*tuple_shape));
1336     }
1337     if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
1338       return InvalidArgument(
1339           "GetTupleElement() index (%d) out of range for tuple shape %s", index,
1340           ShapeUtil::HumanString(*tuple_shape));
1341     }
1342     return GetTupleElementInternal(
1343         ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
1344         index);
1345   });
1346 }
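
// Example (illustrative sketch, not part of the original source; names are
// hypothetical): packing two values into a tuple and projecting one back out.
//
//   XlaOp t = Tuple(&b, {x, y});
//   XlaOp y_again = GetTupleElement(t, /*index=*/1);  // shape of `y`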
1347 
1348 StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
1349                                                     XlaOp tuple_data,
1350                                                     int64_t index) {
1351   HloInstructionProto instr;
1352   *instr.mutable_shape() = shape.ToProto();
1353   instr.set_tuple_index(index);
1354   return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
1355                         {tuple_data});
1356 }
1357 
1358 XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
1359                       const PrecisionConfig* precision_config,
1360                       std::optional<PrimitiveType> preferred_element_type) {
1361   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1362     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1363 
1364     DotDimensionNumbers dimension_numbers;
1365     dimension_numbers.add_lhs_contracting_dimensions(
1366         lhs_shape->dimensions_size() == 1 ? 0 : 1);
1367     dimension_numbers.add_rhs_contracting_dimensions(0);
1368     return DotGeneral(lhs, rhs, dimension_numbers, precision_config,
1369                       preferred_element_type);
1370   });
1371 }
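
// Example (illustrative sketch, not part of the original source; `rhs` is a
// hypothetical f32[3,4] operand): Dot() contracts the last dimension of lhs
// with the first dimension of rhs, so f32[2,3] times f32[3,4] gives f32[2,4].
//
//   XlaOp lhs = ConstantR2<float>(&b, {{1, 2, 3}, {4, 5, 6}});
//   XlaOp product = Dot(lhs, rhs);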
1372 
1373 XlaOp XlaBuilder::DotGeneral(
1374     XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
1375     const PrecisionConfig* precision_config,
1376     std::optional<PrimitiveType> preferred_element_type) {
1377   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1378     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1379     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1380     TF_ASSIGN_OR_RETURN(
1381         Shape shape,
1382         ShapeInference::InferDotOpShape(
1383             *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
1384     return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
1385                               precision_config);
1386   });
1387 }
1388 
1389 StatusOr<XlaOp> XlaBuilder::DotGeneralInternal(
1390     const Shape& shape, XlaOp lhs, XlaOp rhs,
1391     const DotDimensionNumbers& dimension_numbers,
1392     const PrecisionConfig* precision_config) {
1393   HloInstructionProto instr;
1394   *instr.mutable_shape() = shape.ToProto();
1395   *instr.mutable_dot_dimension_numbers() = dimension_numbers;
1396   if (precision_config != nullptr) {
1397     *instr.mutable_precision_config() = *precision_config;
1398   }
1399   return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
1400 }
1401 
1402 Status XlaBuilder::VerifyConvolution(
1403     const Shape& lhs_shape, const Shape& rhs_shape,
1404     const ConvolutionDimensionNumbers& dimension_numbers) const {
1405   if (lhs_shape.rank() != rhs_shape.rank()) {
1406     return InvalidArgument(
1407         "Convolution arguments must have same number of "
1408         "dimensions. Got: %s and %s",
1409         ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1410   }
1411   int num_dims = lhs_shape.rank();
1412   if (num_dims < 2) {
1413     return InvalidArgument(
1414         "Convolution expects argument arrays with >= 2 dimensions. "
1415         "Got: %s and %s",
1416         ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1417   }
1418   int num_spatial_dims = num_dims - 2;
1419 
1420   const auto check_spatial_dimensions = [&](absl::string_view field_name,
1421                                             absl::Span<const int64_t> numbers) {
1422     if (numbers.size() != num_spatial_dims) {
1423       return InvalidArgument("Expected %d elements for %s, but got %d.",
1424                              num_spatial_dims, field_name, numbers.size());
1425     }
1426     for (int i = 0; i < numbers.size(); ++i) {
1427       if (numbers[i] < 0 || numbers[i] >= num_dims) {
1428         return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
1429                                field_name, i, numbers[i]);
1430       }
1431     }
1432     return OkStatus();
1433   };
1434   TF_RETURN_IF_ERROR(
1435       check_spatial_dimensions("input_spatial_dimensions",
1436                                dimension_numbers.input_spatial_dimensions()));
1437   TF_RETURN_IF_ERROR(
1438       check_spatial_dimensions("kernel_spatial_dimensions",
1439                                dimension_numbers.kernel_spatial_dimensions()));
1440   return check_spatial_dimensions(
1441       "output_spatial_dimensions",
1442       dimension_numbers.output_spatial_dimensions());
1443 }
1444 
1445 XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
1446                        absl::Span<const int64_t> window_strides,
1447                        Padding padding, int64_t feature_group_count,
1448                        int64_t batch_group_count,
1449                        const PrecisionConfig* precision_config,
1450                        std::optional<PrimitiveType> preferred_element_type) {
1451   return ConvWithGeneralDimensions(
1452       lhs, rhs, window_strides, padding,
1453       CreateDefaultConvDimensionNumbers(window_strides.size()),
1454       feature_group_count, batch_group_count, precision_config,
1455       preferred_element_type);
1456 }
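
// Example (illustrative sketch, not part of the original source; `lhs` and
// `rhs` are hypothetical): with the default dimension numbers the input is
// laid out as [batch, feature, spatial...] and the kernel as
// [output_feature, input_feature, spatial...].
//
//   // lhs: f32[1,1,4,4], rhs: f32[1,1,2,2] -> result: f32[1,1,3,3].
//   XlaOp out = Conv(lhs, rhs, /*window_strides=*/{1, 1}, Padding::kValid);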
1457 
1458 XlaOp XlaBuilder::ConvWithGeneralPadding(
1459     XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1460     absl::Span<const std::pair<int64_t, int64_t>> padding,
1461     int64_t feature_group_count, int64_t batch_group_count,
1462     const PrecisionConfig* precision_config,
1463     std::optional<PrimitiveType> preferred_element_type) {
1464   return ConvGeneral(lhs, rhs, window_strides, padding,
1465                      CreateDefaultConvDimensionNumbers(window_strides.size()),
1466                      feature_group_count, batch_group_count, precision_config,
1467                      preferred_element_type);
1468 }
1469 
1470 XlaOp XlaBuilder::ConvWithGeneralDimensions(
1471     XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1472     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1473     int64_t feature_group_count, int64_t batch_group_count,
1474     const PrecisionConfig* precision_config,
1475     std::optional<PrimitiveType> preferred_element_type) {
1476   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1477     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1478     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1479 
1480     TF_RETURN_IF_ERROR(
1481         VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1482 
1483     std::vector<int64_t> base_area_dimensions(
1484         dimension_numbers.input_spatial_dimensions_size());
1485     for (std::vector<int64_t>::size_type i = 0; i < base_area_dimensions.size();
1486          ++i) {
1487       base_area_dimensions[i] =
1488           lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
1489     }
1490 
1491     std::vector<int64_t> window_dimensions(
1492         dimension_numbers.kernel_spatial_dimensions_size());
1493     for (std::vector<int64_t>::size_type i = 0; i < window_dimensions.size();
1494          ++i) {
1495       window_dimensions[i] =
1496           rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1497     }
1498 
1499     return ConvGeneral(lhs, rhs, window_strides,
1500                        MakePadding(base_area_dimensions, window_dimensions,
1501                                    window_strides, padding),
1502                        dimension_numbers, feature_group_count,
1503                        batch_group_count, precision_config,
1504                        preferred_element_type);
1505   });
1506 }
1507 
1508 XlaOp XlaBuilder::ConvGeneral(
1509     XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1510     absl::Span<const std::pair<int64_t, int64_t>> padding,
1511     const ConvolutionDimensionNumbers& dimension_numbers,
1512     int64_t feature_group_count, int64_t batch_group_count,
1513     const PrecisionConfig* precision_config,
1514     std::optional<PrimitiveType> preferred_element_type) {
1515   return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
1516                             dimension_numbers, feature_group_count,
1517                             batch_group_count, precision_config,
1518                             preferred_element_type);
1519 }
1520 
1521 // TODO(rmcilroy) Ideally window_reversal would be an absl::Span<const bool>
1522 // however this causes an error in pybind11's conversion code. See
1523 // https://github.com/pybind/pybind11_abseil/issues/4 for details.
1524 XlaOp XlaBuilder::ConvGeneralDilated(
1525     XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1526     absl::Span<const std::pair<int64_t, int64_t>> padding,
1527     absl::Span<const int64_t> lhs_dilation,
1528     absl::Span<const int64_t> rhs_dilation,
1529     const ConvolutionDimensionNumbers& dimension_numbers,
1530     int64_t feature_group_count, int64_t batch_group_count,
1531     const PrecisionConfig* precision_config,
1532     std::optional<PrimitiveType> preferred_element_type,
1533     std::optional<std::vector<bool>> window_reversal) {
1534   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1535     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1536     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1537     TF_RETURN_IF_ERROR(
1538         VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1539 
1540     std::vector<int64_t> window_dimensions(
1541         dimension_numbers.kernel_spatial_dimensions_size());
1542     for (std::vector<int64_t>::size_type i = 0; i < window_dimensions.size();
1543          ++i) {
1544       window_dimensions[i] =
1545           rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1546     }
1547 
1548     TF_ASSIGN_OR_RETURN(Window window,
1549                         ShapeInference::InferWindowFromDimensions(
1550                             window_dimensions, window_strides, padding,
1551                             lhs_dilation, rhs_dilation, window_reversal));
1552     TF_ASSIGN_OR_RETURN(
1553         Shape shape,
1554         ShapeInference::InferConvolveShape(
1555             *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
1556             window, dimension_numbers, preferred_element_type));
1557     return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
1558                                       padding, lhs_dilation, rhs_dilation,
1559                                       dimension_numbers, feature_group_count,
1560                                       batch_group_count, precision_config);
1561   });
1562 }
1563 
1564 StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
1565     XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1566     absl::Span<const std::pair<int64_t, int64_t>> padding,
1567     absl::Span<const int64_t> lhs_dilation,
1568     absl::Span<const int64_t> rhs_dilation,
1569     const ConvolutionDimensionNumbers& dimension_numbers,
1570     int64_t feature_group_count, int64_t batch_group_count,
1571     const PrecisionConfig* precision_config, PaddingType padding_type,
1572     std::optional<PrimitiveType> preferred_element_type) {
1573   TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1574   TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1575   std::vector<int64_t> window_dimensions(
1576       dimension_numbers.kernel_spatial_dimensions_size());
1577   for (std::vector<int64_t>::size_type i = 0; i < window_dimensions.size();
1578        ++i) {
1579     window_dimensions[i] =
1580         rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1581   }
1582 
1583   TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
1584                                          window_dimensions, window_strides,
1585                                          padding, lhs_dilation, rhs_dilation));
1586   TF_ASSIGN_OR_RETURN(
1587       Shape shape,
1588       ShapeInference::InferConvolveShape(
1589           *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
1590           window, dimension_numbers, preferred_element_type));
1591 
1592   HloInstructionProto instr;
1593   *instr.mutable_shape() = shape.ToProto();
1594 
1595   *instr.mutable_window() = window;
1596   *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
1597   instr.set_feature_group_count(feature_group_count);
1598   instr.set_batch_group_count(batch_group_count);
1599   instr.set_padding_type(padding_type);
1600 
1601   if (precision_config != nullptr) {
1602     *instr.mutable_precision_config() = *precision_config;
1603   }
1604   return std::move(instr);
1605 }
1606 
1607 XlaOp XlaBuilder::DynamicConvInputGrad(
1608     XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
1609     absl::Span<const int64_t> window_strides,
1610     absl::Span<const std::pair<int64_t, int64_t>> padding,
1611     absl::Span<const int64_t> lhs_dilation,
1612     absl::Span<const int64_t> rhs_dilation,
1613     const ConvolutionDimensionNumbers& dimension_numbers,
1614     int64_t feature_group_count, int64_t batch_group_count,
1615     const PrecisionConfig* precision_config, PaddingType padding_type,
1616     std::optional<PrimitiveType> preferred_element_type) {
1617   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1618     TF_ASSIGN_OR_RETURN(
1619         HloInstructionProto instr,
1620         DynamicConvInstruction(
1621             lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
1622             dimension_numbers, feature_group_count, batch_group_count,
1623             precision_config, padding_type, preferred_element_type));
1624 
1625     instr.set_custom_call_target("DynamicConvolutionInputGrad");
1626 
1627     return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
1628                           {input_sizes, lhs, rhs});
1629   });
1630 }
1631 
1632 XlaOp XlaBuilder::DynamicConvKernelGrad(
1633     XlaOp activations, XlaOp gradients,
1634     absl::Span<const int64_t> window_strides,
1635     absl::Span<const std::pair<int64_t, int64_t>> padding,
1636     absl::Span<const int64_t> lhs_dilation,
1637     absl::Span<const int64_t> rhs_dilation,
1638     const ConvolutionDimensionNumbers& dimension_numbers,
1639     int64_t feature_group_count, int64_t batch_group_count,
1640     const PrecisionConfig* precision_config, PaddingType padding_type,
1641     std::optional<PrimitiveType> preferred_element_type) {
1642   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1643     TF_ASSIGN_OR_RETURN(
1644         HloInstructionProto instr,
1645         DynamicConvInstruction(activations, gradients, window_strides, padding,
1646                                lhs_dilation, rhs_dilation, dimension_numbers,
1647                                feature_group_count, batch_group_count,
1648                                precision_config, padding_type,
1649                                preferred_element_type));
1650 
1651     instr.set_custom_call_target("DynamicConvolutionKernelGrad");
1652     // The gradient of kernel has kernel shape and shouldn't have any dynamic
1653     // sizes.
1654     instr.mutable_shape()->clear_is_dynamic_dimension();
1655     return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
1656                           {activations, gradients});
1657   });
1658 }
1659 
1660 XlaOp XlaBuilder::DynamicConvForward(
1661     XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1662     absl::Span<const std::pair<int64_t, int64_t>> padding,
1663     absl::Span<const int64_t> lhs_dilation,
1664     absl::Span<const int64_t> rhs_dilation,
1665     const ConvolutionDimensionNumbers& dimension_numbers,
1666     int64_t feature_group_count, int64_t batch_group_count,
1667     const PrecisionConfig* precision_config, PaddingType padding_type,
1668     std::optional<PrimitiveType> preferred_element_type) {
1669   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1670     TF_ASSIGN_OR_RETURN(
1671         HloInstructionProto instr,
1672         DynamicConvInstruction(
1673             lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
1674             dimension_numbers, feature_group_count, batch_group_count,
1675             precision_config, padding_type, preferred_element_type));
1676     instr.set_custom_call_target("DynamicConvolutionForward");
1677 
1678     return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
1679   });
1680 }
1681 
1682 StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
1683     const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
1684     absl::Span<const int64_t> window_strides,
1685     absl::Span<const std::pair<int64_t, int64_t>> padding,
1686     absl::Span<const int64_t> lhs_dilation,
1687     absl::Span<const int64_t> rhs_dilation,
1688     const ConvolutionDimensionNumbers& dimension_numbers,
1689     int64_t feature_group_count, int64_t batch_group_count,
1690     const PrecisionConfig* precision_config) {
1691   HloInstructionProto instr;
1692   *instr.mutable_shape() = shape.ToProto();
1693 
1694   *instr.mutable_window() = window;
1695   *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
1696   instr.set_feature_group_count(feature_group_count);
1697   instr.set_batch_group_count(batch_group_count);
1698 
1699   if (precision_config != nullptr) {
1700     *instr.mutable_precision_config() = *precision_config;
1701   }
1702 
1703   return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
1704 }
1705 
1706 XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
1707                       const absl::Span<const int64_t> fft_length) {
1708   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1709     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1710     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
1711                                          *operand_shape, fft_type, fft_length));
1712     return FftInternal(shape, operand, fft_type, fft_length);
1713   });
1714 }
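
// Example (illustrative sketch, not part of the original source; `operand` is
// a hypothetical c64[8] value): a forward complex-to-complex FFT over the
// last dimension.
//
//   XlaOp freq = Fft(operand, FftType::FFT, /*fft_length=*/{8});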
1715 
1716 StatusOr<XlaOp> XlaBuilder::FftInternal(
1717     const Shape& shape, XlaOp operand, const FftType fft_type,
1718     const absl::Span<const int64_t> fft_length) {
1719   HloInstructionProto instr;
1720   *instr.mutable_shape() = shape.ToProto();
1721   instr.set_fft_type(fft_type);
1722   for (int64_t i : fft_length) {
1723     instr.add_fft_length(i);
1724   }
1725 
1726   return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
1727 }
1728 
1729 StatusOr<XlaOp> XlaBuilder::TriangularSolveInternal(
1730     const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
1731   HloInstructionProto instr;
1732   *instr.mutable_triangular_solve_options() = std::move(options);
1733   *instr.mutable_shape() = shape.ToProto();
1734 
1735   return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b});
1736 }
1737 
1738 StatusOr<XlaOp> XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
1739                                              bool lower) {
1740   HloInstructionProto instr;
1741   CholeskyOptions& options = *instr.mutable_cholesky_options();
1742   options.set_lower(lower);
1743   *instr.mutable_shape() = shape.ToProto();
1744 
1745   return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
1746 }
1747 
1748 XlaOp XlaBuilder::Infeed(const Shape& shape, const std::string& config) {
1749   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1750     HloInstructionProto instr;
1751     if (!LayoutUtil::HasLayout(shape)) {
1752       return InvalidArgument("Given shape to Infeed must have a layout");
1753     }
1754     const Shape infeed_instruction_shape =
1755         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1756     *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1757     instr.set_infeed_config(config);
1758 
1759     if (shape.IsArray() && sharding() &&
1760         sharding()->type() == OpSharding::OTHER) {
1761       // TODO(b/110793772): Support tiled array-shaped infeeds.
1762       return InvalidArgument(
1763           "Tiled sharding is not yet supported for array-shaped infeeds");
1764     }
1765 
1766     if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1767       return InvalidArgument(
1768           "Replicated sharding is not yet supported for infeeds");
1769     }
1770 
1771     // Infeed takes a single token operand. Generate the token to pass to the
1772     // infeed.
1773     XlaOp token;
1774     auto make_token = [&]() {
1775       HloInstructionProto token_instr;
1776       *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1777       return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
1778     };
1779     if (sharding()) {
1780       // Arbitrarily assign token to device 0.
1781       OpSharding sharding = sharding_builder::AssignDevice(0);
1782       XlaScopedShardingAssignment scoped_sharding(this, sharding);
1783       TF_ASSIGN_OR_RETURN(token, make_token());
1784     } else {
1785       TF_ASSIGN_OR_RETURN(token, make_token());
1786     }
1787 
1788     // The sharding is set by the client according to the data tuple shape.
1789     // However, the shape of the infeed instruction is a tuple containing the
1790     // data and a token. For tuple sharding type, the sharding must be changed
1791     // to accommodate the token.
1792     XlaOp infeed;
1793     if (sharding() && sharding()->type() == OpSharding::TUPLE) {
1794       // TODO(b/80000000): Remove this when clients have been updated to handle
1795       // tokens.
1796       OpSharding infeed_instruction_sharding = *sharding();
1797       // Arbitrarily assign the token to device 0.
1798       *infeed_instruction_sharding.add_tuple_shardings() =
1799           sharding_builder::AssignDevice(0);
1800       XlaScopedShardingAssignment scoped_sharding(this,
1801                                                   infeed_instruction_sharding);
1802       TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1803                                                  HloOpcode::kInfeed, {token}));
1804     } else {
1805       TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1806                                                  HloOpcode::kInfeed, {token}));
1807     }
1808 
1809     // The infeed instruction produces a tuple of the infeed data and a
1810     // token. Return an XLA op containing the data.
1811     // TODO(b/80000000): Remove this when clients have been updated to handle
1812     // tokens.
1813     HloInstructionProto infeed_data;
1814     *infeed_data.mutable_shape() = shape.ToProto();
1815     infeed_data.set_tuple_index(0);
1816     return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
1817                           {infeed});
1818   });
1819 }
1820 
1821 XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
1822                                   const std::string& config) {
1823   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1824     if (!LayoutUtil::HasLayout(shape)) {
1825       return InvalidArgument("Given shape to Infeed must have a layout");
1826     }
1827     const Shape infeed_instruction_shape =
1828         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1829 
1830     if (shape.IsArray() && sharding() &&
1831         sharding()->type() == OpSharding::OTHER) {
1832       // TODO(b/110793772): Support tiled array-shaped infeeds.
1833       return InvalidArgument(
1834           "Tiled sharding is not yet supported for array-shaped infeeds");
1835     }
1836 
1837     if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1838       return InvalidArgument(
1839           "Replicated sharding is not yet supported for infeeds");
1840     }
1841     return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
1842   });
1843 }
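
// Example (illustrative sketch, not part of the original source; assumes a
// layout-bearing shape helper such as ShapeUtil::MakeShapeWithDescendingLayout
// and hypothetical names): explicit token threading from an infeed into an
// outfeed.
//
//   XlaOp tok = CreateToken(&b);
//   Shape s = ShapeUtil::MakeShapeWithDescendingLayout(F32, {4});
//   XlaOp in = InfeedWithToken(tok, s);   // tuple: (data, token)
//   XlaOp data = GetTupleElement(in, 0);
//   XlaOp tok2 = GetTupleElement(in, 1);
//   XlaOp done = OutfeedWithToken(data, tok2, s, /*outfeed_config=*/"");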
1844 
1845 StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
1846     const Shape& infeed_instruction_shape, XlaOp token,
1847     const std::string& config) {
1848   HloInstructionProto instr;
1849   *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1850   instr.set_infeed_config(config);
1851   return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
1852 }
1853 
1854 void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
1855                          const std::string& outfeed_config) {
1856   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1857     HloInstructionProto instr;
1858 
1859     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1860 
1861     // Check and set outfeed shape.
1862     if (!LayoutUtil::HasLayout(shape_with_layout)) {
1863       return InvalidArgument("Given shape to Outfeed must have a layout");
1864     }
1865     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1866     if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1867       return InvalidArgument(
1868           "Outfeed shape %s must be compatible with operand shape %s",
1869           ShapeUtil::HumanStringWithLayout(shape_with_layout),
1870           ShapeUtil::HumanStringWithLayout(*operand_shape));
1871     }
1872     *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1873 
1874     instr.set_outfeed_config(outfeed_config);
1875 
1876     // Outfeed takes a token as its second operand. Generate the token to pass
1877     // to the outfeed.
1878     HloInstructionProto token_instr;
1879     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1880     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
1881                                                     HloOpcode::kAfterAll, {}));
1882 
1883     TF_RETURN_IF_ERROR(
1884         AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
1885             .status());
1886 
1887     // The outfeed instruction produces a token. However, existing users expect
1888     // a nil shape (empty tuple). This should only be relevant if the outfeed is
1889     // the root of a computation.
1890     // TODO(b/80000000): Remove this when clients have been updated to handle
1891     // tokens.
1892     HloInstructionProto tuple_instr;
1893     *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
1894 
1895     // The dummy tuple should have no sharding.
1896     {
1897       XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
1898       TF_ASSIGN_OR_RETURN(
1899           XlaOp empty_tuple,
1900           AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
1901       return empty_tuple;
1902     }
1903   });
1904 }
1905 
1906 XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
1907                                    const Shape& shape_with_layout,
1908                                    const std::string& outfeed_config) {
1909   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1910     // Check and set outfeed shape.
1911     if (!LayoutUtil::HasLayout(shape_with_layout)) {
1912       return InvalidArgument("Given shape to Outfeed must have a layout");
1913     }
1914     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1915     if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1916       return InvalidArgument(
1917           "Outfeed shape %s must be compatible with operand shape %s",
1918           ShapeUtil::HumanStringWithLayout(shape_with_layout),
1919           ShapeUtil::HumanStringWithLayout(*operand_shape));
1920     }
1921     return OutfeedWithTokenInternal(operand, token, shape_with_layout,
1922                                     outfeed_config);
1923   });
1924 }
1925 
1926 StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
1927     XlaOp operand, XlaOp token, const Shape& shape_with_layout,
1928     const std::string& outfeed_config) {
1929   HloInstructionProto instr;
1930   *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1931   *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1932   instr.set_outfeed_config(outfeed_config);
1933   return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
1934                         {operand, token});
1935 }
1936 
1937 XlaOp XlaBuilder::CreateToken() {
1938   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1939     HloInstructionProto instr;
1940     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1941     return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
1942   });
1943 }
1944 
1945 XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
1946   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1947     if (tokens.empty()) {
1948       return InvalidArgument("AfterAll requires at least one operand");
1949     }
1950     for (int i = 0, end = tokens.size(); i < end; ++i) {
1951       XlaOp operand = tokens[i];
1952       TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1953       if (!operand_shape->IsToken()) {
1954         return InvalidArgument(
1955             "All operands to AfterAll must be tokens; operand %d has shape %s",
1956             i, ShapeUtil::HumanString(*operand_shape));
1957       }
1958     }
1959     HloInstructionProto instr;
1960     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1961     return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
1962   });
1963 }
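
// Example (illustrative sketch, not part of the original source): joining two
// independent token chains into a single ordering token.
//
//   XlaOp t1 = CreateToken(&b);
//   XlaOp t2 = CreateToken(&b);
//   XlaOp joined = AfterAll(&b, {t1, t2});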
1964 
1965 XlaOp XlaBuilder::CustomCall(
1966     const std::string& call_target_name, absl::Span<const XlaOp> operands,
1967     const Shape& shape, const std::string& opaque,
1968     std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
1969     bool has_side_effect,
1970     absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1971         output_operand_aliasing,
1972     const Literal* literal, std::optional<Window> window,
1973     std::optional<ConvolutionDimensionNumbers> dnums,
1974     CustomCallSchedule schedule, CustomCallApiVersion api_version) {
1975   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1976     if (absl::StartsWith(call_target_name, "$")) {
1977       return InvalidArgument(
1978           "Invalid custom_call_target \"%s\": Call targets that start with '$' "
1979           "are reserved for internal use.",
1980           call_target_name);
1981     }
1982     if (operand_shapes_with_layout.has_value()) {
1983       if (!LayoutUtil::HasLayout(shape)) {
1984         return InvalidArgument(
1985             "Result shape must have layout for custom call with constrained "
1986             "layout.");
1987       }
1988       if (operands.size() != operand_shapes_with_layout->size()) {
1989         return InvalidArgument(
1990             "Must specify a shape with layout for each operand for custom call "
1991             "with constrained layout; given %d shapes, expected %d",
1992             operand_shapes_with_layout->size(), operands.size());
1993       }
1994       int64_t operand_num = 0;
1995       for (const Shape& operand_shape : *operand_shapes_with_layout) {
1996         if (!LayoutUtil::HasLayout(operand_shape)) {
1997           return InvalidArgument(
1998               "No layout specified for operand %d for custom call with "
1999               "constrained layout.",
2000               operand_num);
2001         }
2002         ++operand_num;
2003       }
2004     }
2005     return CustomCallInternal(
2006         call_target_name, operands, /*computation=*/nullptr, shape, opaque,
2007         operand_shapes_with_layout, has_side_effect, output_operand_aliasing,
2008         literal, window, dnums, schedule, api_version);
2009   });
2010 }
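
// Example (illustrative sketch, not part of the original source; the target
// name "my_target" is hypothetical and must be registered with the runtime):
// invoking a custom-call target that produces a f32[4] result.
//
//   XlaOp out = CustomCall(&b, "my_target", /*operands=*/{x},
//                          ShapeUtil::MakeShape(F32, {4}));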
2011 
2012 StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
2013     const std::string& call_target_name, absl::Span<const XlaOp> operands,
2014     const XlaComputation* computation, const Shape& shape,
2015     const std::string& opaque,
2016     std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
2017     bool has_side_effect,
2018     absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2019         output_operand_aliasing,
2020     const Literal* literal, std::optional<Window> window,
2021     std::optional<ConvolutionDimensionNumbers> dnums,
2022     CustomCallSchedule schedule, CustomCallApiVersion api_version) {
2023   HloInstructionProto instr;
2024   // Bit of a hack: cudnn conv custom-calls are created through this API. Give
2025   // them a user-friendly name. (This has no effect on correctness, it's just
2026   // cosmetic.)
2027   if (call_target_name == "__cudnn$convForward") {
2028     instr.set_name("cudnn-conv");
2029   } else if (call_target_name == "__cudnn$convBackwardInput") {
2030     instr.set_name("cudnn-conv-bw-input");
2031   } else if (call_target_name == "__cudnn$convBackwardFilter") {
2032     instr.set_name("cudnn-conv-bw-filter");
2033   } else if (call_target_name == "__cudnn$convBiasActivationForward") {
2034     instr.set_name("cudnn-conv-bias-activation");
2035   }
2036   *instr.mutable_shape() = shape.ToProto();
2037   instr.set_custom_call_target(call_target_name);
2038   instr.set_backend_config(opaque);
2039   if (operand_shapes_with_layout.has_value()) {
2040     instr.set_constrain_layout(true);
2041     for (const Shape& operand_shape : *operand_shapes_with_layout) {
2042       *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
2043     }
2044   }
2045   if (literal != nullptr) {
2046     *instr.mutable_literal() = literal->ToProto();
2047   }
2048   instr.set_custom_call_has_side_effect(has_side_effect);
2049   if (computation != nullptr && !computation->IsNull()) {
2050     AddCalledComputation(*computation, &instr);
2051   }
2052   for (const auto& pair : output_operand_aliasing) {
2053     auto aliasing = instr.add_custom_call_output_operand_aliasing();
2054     aliasing->set_operand_index(pair.second.first);
2055     for (int64_t index : pair.second.second) {
2056       aliasing->add_operand_shape_index(index);
2057     }
2058     for (int64_t index : pair.first) {
2059       aliasing->add_output_shape_index(index);
2060     }
2061   }
2062   if (window.has_value()) {
2063     *instr.mutable_window() = *window;
2064   }
2065   if (dnums.has_value()) {
2066     *instr.mutable_convolution_dimension_numbers() = *dnums;
2067   }
2068   instr.set_custom_call_schedule(schedule);
2069   instr.set_custom_call_api_version(api_version);
2070   return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
2071 }
2072 
2073 XlaOp XlaBuilder::CustomCall(
2074     const std::string& call_target_name, absl::Span<const XlaOp> operands,
2075     const XlaComputation& computation, const Shape& shape,
2076     const std::string& opaque,
2077     std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
2078     bool has_side_effect,
2079     absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2080         output_operand_aliasing,
2081     const Literal* literal, CustomCallSchedule schedule,
2082     CustomCallApiVersion api_version) {
2083   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2084     if (absl::StartsWith(call_target_name, "$")) {
2085       return InvalidArgument(
2086           "Invalid custom_call_target \"%s\": Call targets that start with '$' "
2087           "are reserved for internal use.",
2088           call_target_name);
2089     }
2090     if (operand_shapes_with_layout.has_value()) {
2091       if (!LayoutUtil::HasLayout(shape)) {
2092         return InvalidArgument(
2093             "Result shape must have layout for custom call with constrained "
2094             "layout.");
2095       }
2096       if (operands.size() != operand_shapes_with_layout->size()) {
2097         return InvalidArgument(
2098             "Must specify a shape with layout for each operand for custom call "
2099             "with constrained layout; given %d shapes, expected %d",
2100             operand_shapes_with_layout->size(), operands.size());
2101       }
2102       int64_t operand_num = 0;
2103       for (const Shape& operand_shape : *operand_shapes_with_layout) {
2104         if (!LayoutUtil::HasLayout(operand_shape)) {
2105           return InvalidArgument(
2106               "No layout specified for operand %d for custom call with "
2107               "constrained layout.",
2108               operand_num);
2109         }
2110         ++operand_num;
2111       }
2112     }
2113     return CustomCallInternal(
2114         call_target_name, operands, &computation, shape, opaque,
2115         operand_shapes_with_layout, has_side_effect, output_operand_aliasing,
2116         literal, /*window=*/{}, /*dnums=*/{}, schedule, api_version);
2117   });
2118 }
2119 
2120 XlaOp XlaBuilder::OptimizationBarrier(XlaOp operand) {
2121   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2122     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2123     Shape shape = *operand_shape;
2124     HloInstructionProto instr;
2125     *instr.mutable_shape() = shape.ToProto();
2126     return AddInstruction(std::move(instr), HloOpcode::kOptimizationBarrier,
2127                           {operand});
2128   });
2129 }
2130 
2131 XlaOp XlaBuilder::Transpose(XlaOp operand,
2132                             absl::Span<const int64_t> permutation) {
2133   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2134     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2135     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
2136                                          *operand_shape, permutation));
2137     return TransposeInternal(shape, operand, permutation);
2138   });
2139 }
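
// Example (illustrative sketch, not part of the original source; `x` is a
// hypothetical f32[2,3] operand): permuting the dimensions of a rank-2
// operand, i.e. a 2-D transpose.
//
//   XlaOp xt = Transpose(x, /*permutation=*/{1, 0});  // f32[2,3] -> f32[3,2]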
2140 
2141 StatusOr<XlaOp> XlaBuilder::TransposeInternal(
2142     const Shape& shape, XlaOp operand, absl::Span<const int64_t> permutation) {
2143   HloInstructionProto instr;
2144   *instr.mutable_shape() = shape.ToProto();
2145   for (int64_t dim : permutation) {
2146     instr.add_dimensions(dim);
2147   }
2148   return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
2149 }
2150 
2151 XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64_t> dimensions) {
2152   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2153     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2154     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
2155                                          *operand_shape, dimensions));
2156     return RevInternal(shape, operand, dimensions);
2157   });
2158 }
2159 
2160 StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
2161                                         absl::Span<const int64_t> dimensions) {
2162   HloInstructionProto instr;
2163   *instr.mutable_shape() = shape.ToProto();
2164   for (int64_t dim : dimensions) {
2165     instr.add_dimensions(dim);
2166   }
2167   return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
2168 }
2169 
2170 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
2171                        const XlaComputation& comparator, int64_t dimension,
2172                        bool is_stable) {
2173   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2174     std::vector<const Shape*> operand_shape_ptrs;
2175     TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
2176                         GetOperandShapes(operands));
2177     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2178                       [](const Shape& shape) { return &shape; });
2179     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
2180                                          HloOpcode::kSort, operand_shape_ptrs));
2181     return SortInternal(shape, operands, comparator, dimension, is_stable);
2182   });
2183 }
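
// Example (illustrative sketch, not part of the original source; assumes the
// CreateScalarLtComputation helper from client/lib/comparators.h and a
// hypothetical `keys` operand): an ascending sort along the last dimension,
// which is what dimension == -1 selects.
//
//   XlaComputation less = CreateScalarLtComputation({F32}, &b);
//   XlaOp sorted = Sort({keys}, less, /*dimension=*/-1, /*is_stable=*/true);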
2184 
2185 StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
2186                                          absl::Span<const XlaOp> operands,
2187                                          const XlaComputation& comparator,
2188                                          int64_t dimension, bool is_stable) {
2189   HloInstructionProto instr;
2190   *instr.mutable_shape() = shape.ToProto();
2191   instr.set_is_stable(is_stable);
2192   if (dimension == -1) {
2193     TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
2194     dimension = keys_shape->rank() - 1;
2195   }
2196   instr.add_dimensions(dimension);
2197   AddCalledComputation(comparator, &instr);
2198   return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
2199 }
2200 
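// Illustrative sketch of driving Sort (assumed names and shapes, not taken
// from this file). The comparator is a PRED-returning computation over scalar
// pairs; passing dimension == -1 sorts along the last dimension, as resolved
// in SortInternal above.
//
//   XlaBuilder b("sort_example");
//   XlaOp keys = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {16}), "keys");
//   auto lt_b = b.CreateSubBuilder("lt");
//   Lt(Parameter(lt_b.get(), 0, ShapeUtil::MakeShape(F32, {}), "a"),
//      Parameter(lt_b.get(), 1, ShapeUtil::MakeShape(F32, {}), "b"));
//   XlaComputation comparator = lt_b->Build().value();
//   XlaOp sorted = Sort({keys}, comparator, /*dimension=*/-1,
//                       /*is_stable=*/false);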
2201 XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
2202                                      PrimitiveType new_element_type) {
2203   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2204     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2205     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
2206                                          *operand_shape, new_element_type));
2207     if (primitive_util::IsComplexType(operand_shape->element_type()) &&
2208         !primitive_util::IsComplexType(new_element_type)) {
2209       operand = Real(operand);
2210     }
2211     return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
2212   });
2213 }
2214 
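// Illustrative sketch (the operand c is an assumption): as the branch above
// shows, a complex-to-real conversion first takes the real part, so for c of
// type c64[...] the following is equivalent to ConvertElementType(Real(c),
// F32):
//
//   XlaOp f = ConvertElementType(c, F32);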
2215 XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
2216                                      PrimitiveType new_element_type) {
2217   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2218     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2219     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBitcastConvertShape(
2220                                          *operand_shape, new_element_type));
2221     return BitcastConvertTypeInternal(shape, operand);
2222   });
2223 }
2224 
2225 StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
2226                                                        XlaOp operand) {
2227   HloInstructionProto instr;
2228   *instr.mutable_shape() = shape.ToProto();
2229   return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
2230                         {operand});
2231 }
2232 
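// Illustrative sketch (assumed values): BitcastConvertType reinterprets the
// bits without changing them, so source and destination element types must
// have matching bit widths.
//
//   XlaBuilder b("bitcast_example");
//   XlaOp x = ConstantR0<float>(&b, 1.0f);
//   XlaOp bits = BitcastConvertType(x, U32);  // holds 0x3F800000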
2233 XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
2234   return TernaryOp(HloOpcode::kClamp, min, operand, max);
2235 }
2236 
2237 XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
2238                       const XlaComputation& computation,
2239                       absl::Span<const int64_t> dimensions,
2240                       absl::Span<const XlaOp> static_operands) {
2241   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2242     if (!static_operands.empty()) {
2243       return Unimplemented("static_operands is not supported in Map");
2244     }
2245 
2246     HloInstructionProto instr;
2247     std::vector<const Shape*> operand_shape_ptrs;
2248     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
2249     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2250                       [](const Shape& shape) { return &shape; });
2251     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
2252                         computation.GetProgramShape());
2253     TF_ASSIGN_OR_RETURN(
2254         Shape shape, ShapeInference::InferMapShape(
2255                          operand_shape_ptrs, called_program_shape, dimensions));
2256     *instr.mutable_shape() = shape.ToProto();
2257 
2258     Shape output_shape(instr.shape());
2259     const int64_t output_rank = output_shape.rank();
2260     AddCalledComputation(computation, &instr);
2261     std::vector<XlaOp> new_operands(operands.begin(), operands.end());
2262     for (XlaOp& new_operand : new_operands) {
2263       TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
2264       const int64_t rank = shape->rank();
2265       if (rank != output_rank) {
2266         TF_ASSIGN_OR_RETURN(new_operand,
2267                             InDimBroadcast(output_shape, new_operand, {}));
2268         TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
2269       }
2270       if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
2271         TF_ASSIGN_OR_RETURN(new_operand,
2272                             AddBroadcastSequence(output_shape, new_operand));
2273       }
2274     }
2275 
2276     return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
2277   });
2278 }
2279 
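// Illustrative sketch (assumed computation and shapes): Map applies a scalar
// computation elementwise; as implemented above, operands whose rank or
// dimensions differ from the inferred output are broadcast into shape first.
//
//   XlaBuilder b("map_example");
//   auto add_b = b.CreateSubBuilder("add");
//   Add(Parameter(add_b.get(), 0, ShapeUtil::MakeShape(F32, {}), "a"),
//       Parameter(add_b.get(), 1, ShapeUtil::MakeShape(F32, {}), "b"));
//   XlaComputation add = add_b->Build().value();
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 3}), "y");
//   XlaOp z = Map({x, y}, add, /*dimensions=*/{0, 1});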
2280 XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
2281                         absl::Span<const XlaOp> parameters,
2282                         const Shape& shape) {
2283   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2284     // Check the number of parameters per RNG distribution.
2285     switch (distribution) {
2286       case RandomDistribution::RNG_NORMAL:
2287       case RandomDistribution::RNG_UNIFORM:
2288         if (parameters.size() != 2) {
2289           return InvalidArgument(
2290               "RNG distribution (%s) expects 2 parameters, but got %ld",
2291               RandomDistribution_Name(distribution), parameters.size());
2292         }
2293         break;
2294       default:
2295         LOG(FATAL) << "unhandled distribution " << distribution;
2296     }
2297 
2298     TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
2299     return RngOpInternal(distribution, parameters, shape);
2300   });
2301 }
2302 
2303 StatusOr<XlaOp> XlaBuilder::RngOpInternal(RandomDistribution distribution,
2304                                           absl::Span<const XlaOp> parameters,
2305                                           const Shape& shape) {
2306   HloInstructionProto instr;
2307   *instr.mutable_shape() = shape.ToProto();
2308   instr.set_distribution(distribution);
2309 
2310   return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
2311 }
2312 
2313 XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
2314   return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
2315 }
2316 
2317 XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
2318   return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
2319 }
2320 
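// Illustrative sketch (assumed shapes): both wrappers funnel into RngOp,
// which validates the two-parameter count for these distributions.
//
//   XlaBuilder b("rng_example");
//   Shape s = ShapeUtil::MakeShape(F32, {128});
//   XlaOp n = RngNormal(ConstantR0<float>(&b, 0.0f),
//                       ConstantR0<float>(&b, 1.0f), s);   // mean, stddev
//   XlaOp u = RngUniform(ConstantR0<float>(&b, 0.0f),
//                        ConstantR0<float>(&b, 1.0f), s);  // lo, hi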
2321 XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
2322                                   XlaOp initial_state, const Shape& shape) {
2323   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2324     TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
2325     TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
2326     Shape output_shape = shape;
2327     switch (output_shape.element_type()) {
2328       case PrimitiveType::S8:
2329       case PrimitiveType::U8:
2330         output_shape.set_element_type(PrimitiveType::U8);
2331         break;
2332       case PrimitiveType::BF16:
2333       case PrimitiveType::F16:
2334       case PrimitiveType::S16:
2335       case PrimitiveType::U16:
2336         output_shape.set_element_type(PrimitiveType::U16);
2337         break;
2338       case PrimitiveType::F32:
2339       case PrimitiveType::S32:
2340       case PrimitiveType::U32:
2341         output_shape.set_element_type(PrimitiveType::U32);
2342         break;
2343       case PrimitiveType::F64:
2344       case PrimitiveType::S64:
2345       case PrimitiveType::U64:
2346         output_shape.set_element_type(PrimitiveType::U64);
2347         break;
2348       default:
2349         return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
2350                                PrimitiveType_Name(output_shape.element_type()));
2351     }
2352     return RngBitGeneratorInternal(
2353         ShapeUtil::MakeTupleShapeWithPtrs({&state_shape, &output_shape}),
2354         algorithm, initial_state);
2355   });
2356 }
2357 
2358 StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
2359     const Shape& full_result_shape, RandomAlgorithm algorithm,
2360     XlaOp initial_state) {
2361   HloInstructionProto instr;
2362   *instr.mutable_shape() = full_result_shape.ToProto();
2363   instr.set_rng_algorithm(algorithm);
2364   return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
2365                         {initial_state});
2366 }
2367 
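// Illustrative sketch (the state shape is algorithm-dependent and assumed
// here): the result is a (new_state, values) tuple, and the output element
// type is widened to the matching unsigned integer type per the switch above.
//
//   XlaBuilder b("rng_bits_example");
//   XlaOp state = Parameter(&b, 0, ShapeUtil::MakeShape(U64, {2}), "state");
//   XlaOp out = RngBitGenerator(RandomAlgorithm::RNG_THREE_FRY, state,
//                               ShapeUtil::MakeShape(U32, {8}));
//   XlaOp new_state = GetTupleElement(out, 0);
//   XlaOp bits = GetTupleElement(out, 1);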
2368 XlaOp XlaBuilder::While(const XlaComputation& condition,
2369                         const XlaComputation& body, XlaOp init) {
2370   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2371     // Infer shape.
2372     TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
2373     TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
2374                         condition.GetProgramShape());
2375     TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
2376     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
2377                                          condition_program_shape,
2378                                          body_program_shape, *init_shape));
2379     return WhileInternal(shape, condition, body, init);
2380   });
2381 }
2382 
2383 StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
2384                                           const XlaComputation& condition,
2385                                           const XlaComputation& body,
2386                                           XlaOp init) {
2387   HloInstructionProto instr;
2388   *instr.mutable_shape() = shape.ToProto();
2389   // Body comes before condition computation in the vector.
2390   AddCalledComputation(body, &instr);
2391   AddCalledComputation(condition, &instr);
2392   return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
2393 }
2394 
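// Illustrative sketch (assumed loop bound and names): condition and body are
// built with sub-builders over the same S32 loop-carried state; this loop
// counts from 0 to 10.
//
//   XlaBuilder b("while_example");
//   Shape s32 = ShapeUtil::MakeShape(S32, {});
//   auto cond_b = b.CreateSubBuilder("cond");
//   Lt(Parameter(cond_b.get(), 0, s32, "i"),
//      ConstantR0<int32_t>(cond_b.get(), 10));
//   auto body_b = b.CreateSubBuilder("body");
//   Add(Parameter(body_b.get(), 0, s32, "i"),
//       ConstantR0<int32_t>(body_b.get(), 1));
//   XlaOp result = While(cond_b->Build().value(), body_b->Build().value(),
//                        ConstantR0<int32_t>(&b, 0));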
2395 XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
2396                          const GatherDimensionNumbers& dimension_numbers,
2397                          absl::Span<const int64_t> slice_sizes,
2398                          bool indices_are_sorted) {
2399   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2400     TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
2401     TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
2402                         GetShapePtr(start_indices));
2403     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
2404                                          *input_shape, *start_indices_shape,
2405                                          dimension_numbers, slice_sizes));
2406     return GatherInternal(shape, input, start_indices, dimension_numbers,
2407                           slice_sizes, indices_are_sorted);
2408   });
2409 }
2410 
2411 StatusOr<XlaOp> XlaBuilder::GatherInternal(
2412     const Shape& shape, XlaOp input, XlaOp start_indices,
2413     const GatherDimensionNumbers& dimension_numbers,
2414     absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
2415   HloInstructionProto instr;
2416   instr.set_indices_are_sorted(indices_are_sorted);
2417   *instr.mutable_shape() = shape.ToProto();
2418   *instr.mutable_gather_dimension_numbers() = dimension_numbers;
2419   for (int64_t bound : slice_sizes) {
2420     instr.add_gather_slice_sizes(bound);
2421   }
2422 
2423   return AddInstruction(std::move(instr), HloOpcode::kGather,
2424                         {input, start_indices});
2425 }
2426 
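// Illustrative sketch (assumed shapes): gathering whole rows of a matrix.
// With input f32[N,C] and indices s32[K,1], the result is f32[K,C].
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 holds the row slice
//   dnums.add_collapsed_slice_dims(0);  // the size-1 row dimension is dropped
//   dnums.add_start_index_map(0);       // each index component selects dim 0
//   dnums.set_index_vector_dim(1);
//   XlaOp rows = Gather(input, indices, dnums, /*slice_sizes=*/{1, C});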
2427 XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
2428                           const XlaComputation& update_computation,
2429                           const ScatterDimensionNumbers& dimension_numbers,
2430                           bool indices_are_sorted, bool unique_indices) {
2431   return Scatter(absl::MakeConstSpan(&input, 1), scatter_indices,
2432                  absl::MakeConstSpan(&updates, 1), update_computation,
2433                  dimension_numbers, indices_are_sorted, unique_indices);
2434 }
2435 
2436 XlaOp XlaBuilder::Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
2437                           absl::Span<const XlaOp> updates,
2438                           const XlaComputation& update_computation,
2439                           const ScatterDimensionNumbers& dimension_numbers,
2440                           bool indices_are_sorted, bool unique_indices) {
2441   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2442     if (inputs.empty()) {
2443       return InvalidArgument("Scatter inputs cannot be empty.");
2444     }
2445     if (inputs.size() != updates.size()) {
2446       return InvalidArgument(
2447           "Scatter should have the same number of inputs and updates: %d vs %d.",
2448           inputs.size(), updates.size());
2449     }
2450     absl::InlinedVector<const Shape*, 3> operand_shapes;
2451     operand_shapes.reserve(inputs.size() + 1 + updates.size());
2452     for (const XlaOp& input : inputs) {
2453       TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
2454       operand_shapes.push_back(input_shape);
2455     }
2456     TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
2457                         GetShapePtr(scatter_indices));
2458     operand_shapes.push_back(scatter_indices_shape);
2459     for (const XlaOp& update : updates) {
2460       TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
2461       operand_shapes.push_back(update_shape);
2462     }
2463     TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
2464                         update_computation.GetProgramShape());
2465     TF_ASSIGN_OR_RETURN(Shape shape,
2466                         ShapeInference::InferScatterShape(
2467                             operand_shapes, to_apply_shape, dimension_numbers));
2468     return ScatterInternal(shape, inputs, scatter_indices, updates,
2469                            update_computation, dimension_numbers,
2470                            indices_are_sorted, unique_indices);
2471   });
2472 }
2473 
2474 StatusOr<XlaOp> XlaBuilder::ScatterInternal(
2475     const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
2476     absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
2477     const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
2478     bool unique_indices) {
2479   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2480     HloInstructionProto instr;
2481     instr.set_indices_are_sorted(indices_are_sorted);
2482     instr.set_unique_indices(unique_indices);
2483     *instr.mutable_shape() = shape.ToProto();
2484     *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
2485 
2486     AddCalledComputation(update_computation, &instr);
2487     absl::InlinedVector<XlaOp, 3> operands;
2488     operands.reserve(inputs.size() + 1 + updates.size());
2489     absl::c_copy(inputs, std::back_inserter(operands));
2490     operands.push_back(scatter_indices);
2491     absl::c_copy(updates, std::back_inserter(operands));
2492     return AddInstruction(std::move(instr), HloOpcode::kScatter, operands);
2493   });
2494 }
2495 
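// Illustrative sketch (assumed shapes and update computation): the inverse of
// the row gather above; updates f32[K,C] are combined into the operand rows
// selected by indices s32[K,1], using an "add" update computation.
//
//   ScatterDimensionNumbers dnums;
//   dnums.add_update_window_dims(1);
//   dnums.add_inserted_window_dims(0);
//   dnums.add_scatter_dims_to_operand_dims(0);
//   dnums.set_index_vector_dim(1);
//   XlaOp result = Scatter(input, indices, updates, add_computation, dnums);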
2496 XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
2497                               const XlaComputation& true_computation,
2498                               XlaOp false_operand,
2499                               const XlaComputation& false_computation) {
2500   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2501     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(predicate));
2502 
2503     if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
2504       return InvalidArgument(
2505           "Argument to predicated-Conditional is not a scalar of PRED type "
2506           "(%s).",
2507           ShapeUtil::HumanString(*shape));
2508     }
2509     // The index of true_computation must be 0 and that of
2510     // false_computation must be 1.
2511     return ConditionalImpl(predicate, {&true_computation, &false_computation},
2512                            {true_operand, false_operand});
2513   });
2514 }
2515 
2516 XlaOp XlaBuilder::Conditional(
2517     XlaOp branch_index,
2518     absl::Span<const XlaComputation* const> branch_computations,
2519     absl::Span<const XlaOp> branch_operands) {
2520   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2521     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(branch_index));
2522 
2523     if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
2524       return InvalidArgument(
2525           "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
2526           ShapeUtil::HumanString(*shape));
2527     }
2528     return ConditionalImpl(branch_index, branch_computations, branch_operands);
2529   });
2530 }
2531 
2532 XlaOp XlaBuilder::ConditionalImpl(
2533     XlaOp branch_index,
2534     absl::Span<const XlaComputation* const> branch_computations,
2535     absl::Span<const XlaOp> branch_operands) {
2536   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2537     HloInstructionProto instr;
2538 
2539     TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
2540                         GetShapePtr(branch_index));
2541     std::vector<Shape> branch_operand_shapes(branch_operands.size());
2542     std::vector<ProgramShape> branch_computation_shapes(
2543         branch_computations.size());
2544     for (int j = 0, end = branch_operands.size(); j < end; ++j) {
2545       TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
2546                           GetShape(branch_operands[j]));
2547       TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
2548                           branch_computations[j]->GetProgramShape());
2549     }
2550     TF_ASSIGN_OR_RETURN(const Shape shape,
2551                         ShapeInference::InferConditionalShape(
2552                             *branch_index_shape, branch_computation_shapes,
2553                             branch_operand_shapes));
2554     *instr.mutable_shape() = shape.ToProto();
2555 
2556     for (const XlaComputation* branch_computation : branch_computations) {
2557       AddCalledComputation(*branch_computation, &instr);
2558     }
2559 
2560     std::vector<XlaOp> operands(1, branch_index);
2561     for (const XlaOp branch_operand : branch_operands) {
2562       operands.emplace_back(branch_operand);
2563     }
2564     return AddInstruction(std::move(instr), HloOpcode::kConditional,
2565                           absl::MakeSpan(operands));
2566   });
2567 }
2568 
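// Illustrative sketch (the computations and operands are assumptions): the
// predicated form routes a PRED scalar to branch 0 (true) or branch 1
// (false); the indexed form takes an S32 branch index and a span of
// computations, as both overloads above show.
//
//   XlaOp result = Conditional(pred, true_arg, true_comp,
//                              false_arg, false_comp);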
2569 Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
2570   if (this != op.builder()) {
2571     return InvalidArgument(
2572         "XlaOp with handle %d is built by builder '%s', but is trying to use "
2573         "it in builder '%s'",
2574         op.handle(), op.builder()->name(), name());
2575   }
2576   return OkStatus();
2577 }
2578 
2579 XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
2580                          const XlaComputation& computation,
2581                          absl::Span<const int64_t> dimensions_to_reduce) {
2582   return Reduce(absl::Span<const XlaOp>({operand}),
2583                 absl::Span<const XlaOp>({init_value}), computation,
2584                 dimensions_to_reduce);
2585 }
2586 
2587 XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
2588                          absl::Span<const XlaOp> init_values,
2589                          const XlaComputation& computation,
2590                          absl::Span<const int64_t> dimensions_to_reduce) {
2591   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2592     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
2593                         computation.GetProgramShape());
2594 
2595     std::vector<XlaOp> all_operands;
2596     all_operands.insert(all_operands.end(), operands.begin(), operands.end());
2597     all_operands.insert(all_operands.end(), init_values.begin(),
2598                         init_values.end());
2599 
2600     std::vector<const Shape*> operand_shape_ptrs;
2601     TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
2602                         GetOperandShapes(all_operands));
2603     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
2604                       [](const Shape& shape) { return &shape; });
2605 
2606     TF_ASSIGN_OR_RETURN(
2607         Shape shape,
2608         ShapeInference::InferReduceShape(
2609             operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
2610     return ReduceInternal(shape, all_operands, computation,
2611                           dimensions_to_reduce);
2612   });
2613 }
2614 
2615 StatusOr<XlaOp> XlaBuilder::ReduceInternal(
2616     const Shape& shape, absl::Span<const XlaOp> all_operands,
2617     const XlaComputation& computation,
2618     absl::Span<const int64_t> dimensions_to_reduce) {
2619   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2620     HloInstructionProto instr;
2621     *instr.mutable_shape() = shape.ToProto();
2622 
2623     for (int64_t dim : dimensions_to_reduce) {
2624       instr.add_dimensions(dim);
2625     }
2626 
2627     AddCalledComputation(computation, &instr);
2628     return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
2629   });
2630 }
2631 
2632 XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
2633                             const XlaComputation& computation) {
2634   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2635     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2636     std::vector<int64_t> all_dimnos(operand_shape->rank());
2637     std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
2638     return Reduce(operand, init_value, computation, all_dimnos);
2639   });
2640 }
2641 
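// Illustrative sketch (assumed operand and computation): ReduceAll is just
// Reduce over every dimension, so an f32[2,3] input collapses to a scalar.
//
//   XlaBuilder b("reduce_all_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
//   XlaOp sum = ReduceAll(x, ConstantR0<float>(&b, 0.0f), add_computation);
//   // Same result as Reduce(x, init, add_computation, {0, 1}).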
2642 XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
2643                                const XlaComputation& computation,
2644                                absl::Span<const int64_t> window_dimensions,
2645                                absl::Span<const int64_t> window_strides,
2646                                Padding padding) {
2647   return ReduceWindow(absl::MakeSpan(&operand, 1),
2648                       absl::MakeSpan(&init_value, 1), computation,
2649                       window_dimensions, window_strides, padding);
2650 }
2651 
2652 XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
2653                                absl::Span<const XlaOp> init_values,
2654                                const XlaComputation& computation,
2655                                absl::Span<const int64_t> window_dimensions,
2656                                absl::Span<const int64_t> window_strides,
2657                                Padding padding) {
2658   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2659     const Shape* operand_shape = nullptr;
2660     for (const auto& operand : operands) {
2661       TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
2662       TF_RETURN_IF_ERROR(ValidatePaddingValues(
2663           operand_shape->dimensions(), window_dimensions, window_strides));
2664     }
2665     CHECK(operand_shape != nullptr);
2666     std::vector<std::pair<int64_t, int64_t>> padding_values =
2667         MakePadding(operand_shape->dimensions(), window_dimensions,
2668                     window_strides, padding);
2669     TF_ASSIGN_OR_RETURN(auto window,
2670                         ShapeInference::InferWindowFromDimensions(
2671                             window_dimensions, window_strides, padding_values,
2672                             /*lhs_dilation=*/{},
2673                             /*rhs_dilation=*/{}));
2674     PaddingType padding_type = PADDING_INVALID;
2675     for (int64_t i = 0; i < operand_shape->rank(); ++i) {
2676       if (operand_shape->is_dynamic_dimension(i) &&
2677           !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
2678           padding == Padding::kSame) {
2679         // SAME padding can create dynamic padding sizes. The padding sizes
2680         // need to be rewritten by the dynamic padder using HloInstructions,
2681         // so we create a CustomCall to handle this.
2682         padding_type = PADDING_SAME;
2683       }
2684     }
2685     if (padding_type == PADDING_SAME) {
2686       TF_ASSIGN_OR_RETURN(
2687           HloInstructionProto instr,
2688           ReduceWindowInternal(operands, init_values, computation,
2689                                window_dimensions, window_strides, {}, {},
2690                                padding_values));
2691       instr.set_custom_call_target("DynamicReduceWindowSamePadding");
2692       std::vector<XlaOp> args;
2693       args.insert(args.end(), operands.begin(), operands.end());
2694       args.insert(args.end(), init_values.begin(), init_values.end());
2695       return AddInstruction(std::move(instr), HloOpcode::kCustomCall, args);
2696     }
2697     return ReduceWindowWithGeneralPadding(
2698         operands, init_values, computation, window_dimensions, window_strides,
2699         /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
2700   });
2701 }
2702 
2703 XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
2704     absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
2705     const XlaComputation& computation,
2706     absl::Span<const int64_t> window_dimensions,
2707     absl::Span<const int64_t> window_strides,
2708     absl::Span<const int64_t> base_dilations,
2709     absl::Span<const int64_t> window_dilations,
2710     absl::Span<const std::pair<int64_t, int64_t>> padding) {
2711   std::vector<const Shape*> operand_shapes, init_shapes;
2712   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2713     if (operands.size() == 1) {
2714       const auto& operand = operands[0];
2715       const auto& init_value = init_values[0];
2716       TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2717       operand_shapes.push_back(operand_shape);
2718       TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
2719       init_shapes.push_back(init_shape);
2720 
2721       TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
2722                           computation.GetProgramShape());
2723       TF_ASSIGN_OR_RETURN(auto window,
2724                           ShapeInference::InferWindowFromDimensions(
2725                               window_dimensions, window_strides, padding,
2726                               /*lhs_dilation=*/base_dilations,
2727                               /*rhs_dilation=*/window_dilations));
2728       TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
2729                                            absl::MakeSpan(operand_shapes),
2730                                            absl::MakeSpan(init_shapes), window,
2731                                            to_apply_shape));
2732       return ReduceWindowInternal(shape, operands[0], init_values[0],
2733                                   computation, window);
2734     }
2735 
2736     TF_ASSIGN_OR_RETURN(
2737         HloInstructionProto instr,
2738         ReduceWindowInternal(operands, init_values, computation,
2739                              window_dimensions, window_strides, base_dilations,
2740                              window_dilations, padding));
2741     std::vector<XlaOp> args;
2742     args.insert(args.end(), operands.begin(), operands.end());
2743     args.insert(args.end(), init_values.begin(), init_values.end());
2744     return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, args);
2745   });
2746 }
2747 
2748 StatusOr<HloInstructionProto> XlaBuilder::ReduceWindowInternal(
2749     absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
2750     const XlaComputation& computation,
2751     absl::Span<const int64_t> window_dimensions,
2752     absl::Span<const int64_t> window_strides,
2753     absl::Span<const int64_t> base_dilations,
2754     absl::Span<const int64_t> window_dilations,
2755     absl::Span<const std::pair<int64_t, int64_t>> padding) {
2756   std::vector<const Shape*> operand_shapes, init_shapes;
2757   for (int i = 0; i < operands.size(); ++i) {
2758     const auto& operand = operands[i];
2759     const auto& init_value = init_values[i];
2760     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2761     operand_shapes.push_back(operand_shape);
2762     TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
2763     init_shapes.push_back(init_shape);
2764   }
2765   TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
2766                       computation.GetProgramShape());
2767   TF_ASSIGN_OR_RETURN(auto window,
2768                       ShapeInference::InferWindowFromDimensions(
2769                           window_dimensions, window_strides, padding,
2770                           /*lhs_dilation=*/base_dilations,
2771                           /*rhs_dilation=*/window_dilations));
2772   TF_ASSIGN_OR_RETURN(Shape shape,
2773                       ShapeInference::InferReduceWindowShape(
2774                           absl::MakeSpan(operand_shapes),
2775                           absl::MakeSpan(init_shapes), window, to_apply_shape));
2776   HloInstructionProto instr;
2777   *instr.mutable_shape() = shape.ToProto();
2778   *instr.mutable_window() = std::move(window);
2779   AddCalledComputation(computation, &instr);
2780   return instr;
2781 }
2782 
2783 StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
2784     const Shape& shape, XlaOp operand, XlaOp init_value,
2785     const XlaComputation& computation, Window window) {
2786   HloInstructionProto instr;
2787   *instr.mutable_shape() = shape.ToProto();
2788   *instr.mutable_window() = std::move(window);
2789 
2790   AddCalledComputation(computation, &instr);
2791   return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
2792                         {operand, init_value});
2793 }
2794 
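// Illustrative sketch (assumed NHWC shapes and "max" computation): a 2x2 max
// pool with stride 2, expressed as ReduceWindow over a scalar reducer.
//
//   XlaBuilder b("pool_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 8, 8, 3}), "x");
//   XlaOp init =
//       ConstantR0<float>(&b, -std::numeric_limits<float>::infinity());
//   XlaOp pooled = ReduceWindow(x, init, max_computation,
//                               /*window_dimensions=*/{1, 2, 2, 1},
//                               /*window_strides=*/{1, 2, 2, 1},
//                               Padding::kValid);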
2795 XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
2796                                     float epsilon, int64_t feature_index) {
2797   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2798     HloInstructionProto instr;
2799 
2800     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2801     TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2802     TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2803     TF_ASSIGN_OR_RETURN(
2804         Shape shape,
2805         ShapeInference::InferBatchNormTrainingShape(
2806             *operand_shape, *scale_shape, *offset_shape, feature_index));
2807     *instr.mutable_shape() = shape.ToProto();
2808 
2809     instr.set_epsilon(epsilon);
2810     instr.set_feature_index(feature_index);
2811 
2812     return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
2813                           {operand, scale, offset});
2814   });
2815 }
2816 
2817 XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
2818                                      XlaOp mean, XlaOp variance, float epsilon,
2819                                      int64_t feature_index) {
2820   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2821     HloInstructionProto instr;
2822 
2823     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2824     TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2825     TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2826     TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
2827     TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
2828     TF_ASSIGN_OR_RETURN(Shape shape,
2829                         ShapeInference::InferBatchNormInferenceShape(
2830                             *operand_shape, *scale_shape, *offset_shape,
2831                             *mean_shape, *variance_shape, feature_index));
2832     *instr.mutable_shape() = shape.ToProto();
2833 
2834     instr.set_epsilon(epsilon);
2835     instr.set_feature_index(feature_index);
2836 
2837     return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
2838                           {operand, scale, offset, mean, variance});
2839   });
2840 }
2841 
2842 XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
2843                                 XlaOp batch_var, XlaOp grad_output,
2844                                 float epsilon, int64_t feature_index) {
2845   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2846     HloInstructionProto instr;
2847 
2848     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2849     TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2850     TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape, GetShapePtr(batch_mean));
2851     TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
2852     TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
2853                         GetShapePtr(grad_output));
2854     TF_ASSIGN_OR_RETURN(
2855         Shape shape, ShapeInference::InferBatchNormGradShape(
2856                          *operand_shape, *scale_shape, *batch_mean_shape,
2857                          *batch_var_shape, *grad_output_shape, feature_index));
2858     *instr.mutable_shape() = shape.ToProto();
2859 
2860     instr.set_epsilon(epsilon);
2861     instr.set_feature_index(feature_index);
2862 
2863     return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
2864                           {operand, scale, batch_mean, batch_var, grad_output});
2865   });
2866 }
2867 
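// Illustrative sketch (epsilon and feature_index are example values):
// BatchNormTraining returns a (normalized, batch_mean, batch_var) tuple whose
// mean and variance feed BatchNormGrad during the backward pass.
//
//   XlaOp out = BatchNormTraining(x, scale, offset, /*epsilon=*/1e-5f,
//                                 /*feature_index=*/3);
//   XlaOp normalized = GetTupleElement(out, 0);
//   XlaOp batch_mean = GetTupleElement(out, 1);
//   XlaOp batch_var = GetTupleElement(out, 2);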
2868 XlaOp XlaBuilder::AllGather(XlaOp operand, int64_t all_gather_dimension,
2869                             int64_t shard_count,
2870                             absl::Span<const ReplicaGroup> replica_groups,
2871                             const std::optional<ChannelHandle>& channel_id,
2872                             const std::optional<Layout>& layout,
2873                             const std::optional<bool> use_global_device_ids) {
2874   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2875     HloInstructionProto instr;
2876     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2877 
2878     TF_ASSIGN_OR_RETURN(
2879         Shape inferred_shape,
2880         ShapeInference::InferAllGatherShape({operand_shape},
2881                                             all_gather_dimension, shard_count));
2882     if (layout) {
2883       *inferred_shape.mutable_layout() = *layout;
2884       instr.set_constrain_layout(true);
2885     }
2886     *instr.mutable_shape() = inferred_shape.ToProto();
2887 
2888     instr.add_dimensions(all_gather_dimension);
2889     for (const ReplicaGroup& group : replica_groups) {
2890       *instr.add_replica_groups() = group;
2891     }
2892     if (channel_id.has_value()) {
2893       instr.set_channel_id(channel_id->handle());
2894     }
2895     if (use_global_device_ids.has_value()) {
2896       instr.set_use_global_device_ids(use_global_device_ids.value());
2897     }
2898 
2899     TF_ASSIGN_OR_RETURN(
2900         auto all_gather,
2901         AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand}));
2902     return all_gather;
2903   });
2904 }
2905 
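// Illustrative sketch (assumed sizes): with four participants each holding
// f32[1,128], gathering along dimension 0 yields f32[4,128] on every
// participant.
//
//   XlaOp gathered = AllGather(x, /*all_gather_dimension=*/0,
//                              /*shard_count=*/4);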
2906 XlaOp XlaBuilder::CrossReplicaSum(
2907     XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
2908   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2909     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
2910     const Shape* element_shape;
2911     if (shape->IsTuple()) {
2912       if (shape->tuple_shapes_size() == 0) {
2913         return Unimplemented(
2914             "0-element tuple CrossReplicaSum is not supported");
2915       }
2916       element_shape = &shape->tuple_shapes(0);
2917     } else {
2918       element_shape = shape;
2919     }
2920     const Shape scalar_shape =
2921         ShapeUtil::MakeShape(element_shape->element_type(), {});
2922     auto b = CreateSubBuilder("sum");
2923     auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
2924     auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
2925     if (scalar_shape.element_type() == PRED) {
2926       Or(x, y);
2927     } else {
2928       Add(x, y);
2929     }
2930     TF_ASSIGN_OR_RETURN(auto computation, b->Build());
2931     return AllReduce(operand, computation, replica_groups,
2932                      /*channel_id=*/std::nullopt);
2933   });
2934 }
2935 
2936 XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
2937                             absl::Span<const ReplicaGroup> replica_groups,
2938                             const std::optional<ChannelHandle>& channel_id,
2939                             const std::optional<Shape>& shape_with_layout,
2940                             const std::optional<bool> use_global_device_ids) {
2941   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2942     HloInstructionProto instr;
2943     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2944     std::vector<const Shape*> operand_shapes;
2945     std::vector<XlaOp> operands;
2946     if (operand_shape->IsTuple()) {
2947       if (operand_shape->tuple_shapes_size() == 0) {
2948         return Unimplemented("0-element tuple AllReduce is not supported");
2949       }
2950       for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
2951         if (operand_shape->tuple_shapes(i).element_type() !=
2952             operand_shape->tuple_shapes(0).element_type()) {
2953           return Unimplemented(
2954               "All the shapes of a tuple input of AllReduce must have the same "
2955               "element type");
2956         }
2957         operand_shapes.push_back(&operand_shape->tuple_shapes(i));
2958         operands.push_back(GetTupleElement(operand, i));
2959       }
2960     } else {
2961       operand_shapes.push_back(operand_shape);
2962       operands.push_back(operand);
2963     }
2964 
2965     TF_ASSIGN_OR_RETURN(Shape inferred_shape,
2966                         ShapeInference::InferAllReduceShape(operand_shapes));
2967     if (shape_with_layout) {
2968       if (!LayoutUtil::HasLayout(*shape_with_layout)) {
2969         return InvalidArgument("shape_with_layout must have the layout set: %s",
2970                                shape_with_layout->ToString());
2971       }
2972       if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
2973         return InvalidArgument(
2974             "Provided shape_with_layout must be compatible with the "
2975             "operand shape: %s vs %s",
2976             shape_with_layout->ToString(), operand_shape->ToString());
2977       }
2978       instr.set_constrain_layout(true);
2979       if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
2980         // For a single-element tuple, take the tuple element shape.
2981         TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
2982         *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
2983       } else {
2984         *instr.mutable_shape() = shape_with_layout->ToProto();
2985       }
2986     } else {
2987       *instr.mutable_shape() = inferred_shape.ToProto();
2988     }
2989 
2990     for (const ReplicaGroup& group : replica_groups) {
2991       *instr.add_replica_groups() = group;
2992     }
2993 
2994     if (channel_id.has_value()) {
2995       instr.set_channel_id(channel_id->handle());
2996     }
2997 
2998     if (use_global_device_ids.has_value()) {
2999       instr.set_use_global_device_ids(*use_global_device_ids);
3000     }
3001 
3002     AddCalledComputation(computation, &instr);
3003 
3004     TF_ASSIGN_OR_RETURN(
3005         auto all_reduce,
3006         AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
3007     if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
3008       // For a single-element tuple, wrap the result into a tuple.
3009       TF_RET_CHECK(operand_shapes.size() == 1);
3010       TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
3011       return Tuple({all_reduce});
3012     }
3013     return all_reduce;
3014   });
3015 }
3016 
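// Illustrative sketch (the computation and group layout are example
// assumptions): summing across replicas 0 and 1, and separately across
// replicas 2 and 3.
//
//   ReplicaGroup g0, g1;
//   g0.add_replica_ids(0); g0.add_replica_ids(1);
//   g1.add_replica_ids(2); g1.add_replica_ids(3);
//   XlaOp summed = AllReduce(x, add_computation, {g0, g1});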
3017 XlaOp XlaBuilder::ReduceScatter(
3018     XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
3019     int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups,
3020     const std::optional<ChannelHandle>& channel_id,
3021     const std::optional<Layout>& layout,
3022     const std::optional<bool> use_global_device_ids) {
3023   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3024     HloInstructionProto instr;
3025     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3026     std::vector<const Shape*> operand_shapes;
3027     std::vector<XlaOp> operands;
3028     if (operand_shape->IsTuple()) {
3029       if (operand_shape->tuple_shapes_size() == 0) {
3030         return Unimplemented("0-element tuple ReduceScatter is not supported");
3031       }
3032       for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
3033         if (operand_shape->tuple_shapes(i).element_type() !=
3034             operand_shape->tuple_shapes(0).element_type()) {
3035           return Unimplemented(
3036               "All the shapes of a tuple input of ReduceScatter must have "
3037               "the same "
3038               "element type");
3039         }
3040         operand_shapes.push_back(&operand_shape->tuple_shapes(i));
3041         operands.push_back(GetTupleElement(operand, i));
3042       }
3043     } else {
3044       operand_shapes.push_back(operand_shape);
3045       operands.push_back(operand);
3046     }
3047 
3048     TF_ASSIGN_OR_RETURN(Shape inferred_shape,
3049                         ShapeInference::InferReduceScatterShape(
3050                             operand_shapes, scatter_dimension, shard_count));
3051     if (layout) {
3052       *inferred_shape.mutable_layout() = *layout;
3053       instr.set_constrain_layout(true);
3054     }
3055     *instr.mutable_shape() = inferred_shape.ToProto();
3056 
3057     AddCalledComputation(computation, &instr);
3058 
3059     instr.add_dimensions(scatter_dimension);
3060     for (const ReplicaGroup& group : replica_groups) {
3061       *instr.add_replica_groups() = group;
3062     }
3063     if (channel_id.has_value()) {
3064       instr.set_channel_id(channel_id->handle());
3065     }
3066     if (use_global_device_ids.has_value()) {
3067       instr.set_use_global_device_ids(use_global_device_ids.value());
3068     }
3069 
3070     TF_ASSIGN_OR_RETURN(
3071         auto reduce_scatter,
3072         AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
3073     return reduce_scatter;
3074   });
3075 }
3076 
3077 XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
3078                            int64_t concat_dimension, int64_t split_count,
3079                            absl::Span<const ReplicaGroup> replica_groups,
3080                            const std::optional<Layout>& layout) {
3081   // An array all_to_all may need to violate the layout constraint to be
3082   // legal, so use the tuple version whenever a layout is specified.
3083   if (layout.has_value()) {
3084     return AllToAllTuple(operand, split_dimension, concat_dimension,
3085                          split_count, replica_groups, layout);
3086   }
3087   return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
3088                        replica_groups);
3089 }
3090 
3091 XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
3092                                 int64_t concat_dimension, int64_t split_count,
3093                                 absl::Span<const ReplicaGroup> replica_groups) {
3094   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3095     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3096     TF_ASSIGN_OR_RETURN(
3097         const Shape all_to_all_shape,
3098         ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
3099                                            concat_dimension, split_count));
3100     HloInstructionProto instr;
3101     *instr.mutable_shape() = operand_shape->ToProto();
3102     if (replica_groups.empty()) {
3103       auto* group = instr.add_replica_groups();
3104       for (int64_t i = 0; i < split_count; ++i) {
3105         group->add_replica_ids(i);
3106       }
3107     } else {
3108       for (const ReplicaGroup& group : replica_groups) {
3109         *instr.add_replica_groups() = group;
3110       }
3111     }
3112     instr.add_dimensions(split_dimension);
3113     TF_ASSIGN_OR_RETURN(
3114         XlaOp all_to_all,
3115         AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
3116     if (split_dimension == concat_dimension) {
3117       return all_to_all;
3118     }
3119     DimensionVector sizes;
3120     for (int64_t i = 0; i < operand_shape->rank(); ++i) {
3121       if (i != split_dimension) {
3122         sizes.push_back(operand_shape->dimensions(i));
3123         continue;
3124       }
3125       sizes.push_back(split_count);
3126       sizes.push_back(operand_shape->dimensions(i) / split_count);
3127     }
3128     all_to_all = Reshape(all_to_all, sizes);
3129 
3130     std::vector<int64_t> permutation;
3131     const auto rank = operand_shape->rank();
3132     permutation.reserve(rank + 1);
3133     for (int64_t i = 0; i < rank; ++i) {
3134       int64_t dim_after_reshape = i >= split_dimension ? i + 1 : i;
3135       if (i == concat_dimension) {
3136         permutation.push_back(split_dimension);
3137       }
3138       permutation.push_back(dim_after_reshape);
3139     }
3140     all_to_all = Transpose(all_to_all, permutation);
3141     return Reshape(all_to_all_shape, all_to_all);
3142   });
3143 }
3144 
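// Worked example of the decomposition above (hypothetical sizes): with an
// f32[8,10] operand, split_dimension=0, concat_dimension=1, and
// split_count=4, the all-to-all result keeps the operand shape and is then
//   reshaped   f32[8,10] -> f32[4,2,10]  (split dim 0 into 4 x 2),
//   transposed by permutation {1,0,2} -> f32[2,4,10]  (move the split-count
//   dimension next to the concat dimension), and
//   reshaped to the inferred f32[2,40]  (merge it into the concat dimension).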
3145 XlaOp XlaBuilder::AllToAllTuple(absl::Span<const XlaOp> operands,
3146                                 absl::Span<const ReplicaGroup> replica_groups,
3147                                 const std::optional<Layout>& layout) {
3148   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3149     HloInstructionProto instr;
3150     TF_ASSIGN_OR_RETURN(auto operand_shapes, this->GetOperandShapes(operands));
3151     std::vector<const Shape*> operand_shape_ptrs;
3152     operand_shape_ptrs.reserve(operand_shapes.size());
3153     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
3154                       [](const Shape& shape) { return &shape; });
3155     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferAllToAllTupleShape(
3156                                          operand_shape_ptrs));
3157 
3158     if (layout) {
3159       TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
3160       for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
3161         const int64_t layout_minor_to_major_size =
3162             layout->minor_to_major().size();
3163         if (layout_minor_to_major_size != shape.tuple_shapes(i).rank()) {
3164           return InvalidArgument(
3165               "Provided layout must be compatible with the operands' shape. "
3166               "The layout is %s, but operand %d has shape %s.",
3167               layout->ToString(), i, shape.tuple_shapes(i).ToString());
3168         }
3169         *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
3170       }
3171       instr.set_constrain_layout(true);
3172     }
3173     *instr.mutable_shape() = shape.ToProto();
3174 
3175     for (const ReplicaGroup& group : replica_groups) {
3176       *instr.add_replica_groups() = group;
3177     }
3178 
3179     return AddInstruction(std::move(instr), HloOpcode::kAllToAll, operands);
3180   });
3181 }
3182 
3183 XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64_t split_dimension,
3184                                 int64_t concat_dimension, int64_t split_count,
3185                                 absl::Span<const ReplicaGroup> replica_groups,
3186                                 const std::optional<Layout>& layout) {
3187   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3188     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3189 
3190     // The HloInstruction for AllToAll currently only handles the data
3191     // communication: it accepts N already-split parts and scatters them to
3192     // N cores, and each core gathers the N received parts into a tuple as
3193     // the output. So here we explicitly split the operand before the HLO
3194     // all-to-all and concatenate the received tuple elements afterwards.
3195     //
3196     // First, run shape inference to make sure the shapes are valid.
3197     TF_RETURN_IF_ERROR(
3198         ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
3199                                            concat_dimension, split_count)
3200             .status());
3201 
3202     // Split into N parts.
3203     std::vector<XlaOp> slices;
3204     slices.reserve(split_count);
3205     const int64_t block_size =
3206         operand_shape->dimensions(split_dimension) / split_count;
3207     for (int i = 0; i < split_count; i++) {
3208       slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
3209                                   /*limit_index=*/(i + 1) * block_size,
3210                                   /*stride=*/1, /*dimno=*/split_dimension));
3211     }
3212 
3213     // Handle data communication.
3214     XlaOp alltoall = this->AllToAllTuple(slices, replica_groups, layout);
3215 
3216     // Concat the N received parts.
3217     std::vector<XlaOp> received;
3218     received.reserve(split_count);
3219     for (int i = 0; i < split_count; i++) {
3220       received.push_back(this->GetTupleElement(alltoall, i));
3221     }
3222     return this->ConcatInDim(received, concat_dimension);
3223   });
3224 }
3225 
3226 XlaOp XlaBuilder::CollectivePermute(
3227     XlaOp operand,
3228     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
3229   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3230     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3231     HloInstructionProto instr;
3232     TF_ASSIGN_OR_RETURN(
3233         Shape shape,
3234         ShapeInference::InferCollectivePermuteShape({operand_shape}));
3235     *instr.mutable_shape() = shape.ToProto();
3236 
3237     for (const auto& pair : source_target_pairs) {
3238       auto* proto_pair = instr.add_source_target_pairs();
3239       proto_pair->set_source(pair.first);
3240       proto_pair->set_target(pair.second);
3241     }
3242 
3243     return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
3244                           {operand});
3245   });
3246 }
3247 
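// Illustrative sketch (assumed example, not from the original source): a
// rotation among four replicas, where each replica forwards its data to the
// next one. A replica that is not the target of any pair receives zeros.
//
//   std::vector<std::pair<int64_t, int64_t>> pairs = {
//       {0, 1}, {1, 2}, {2, 3}, {3, 0}};
//   XlaOp rotated = CollectivePermute(x, pairs);
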
XlaOp XlaBuilder::ReplicaId() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
  });
}

XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                   absl::Span<const int64_t> window_dimensions,
                                   absl::Span<const int64_t> window_strides,
                                   Padding padding, XlaOp source,
                                   XlaOp init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    std::vector<std::pair<int64_t, int64_t>> padding_values =
        MakePadding(operand_shape->dimensions(), window_dimensions,
                    window_strides, padding);

    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions.
        // We create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          SelectAndScatterInternal(operand, select, window_dimensions,
                                   window_strides, padding_values, source,
                                   init_value, scatter));
      instr.set_custom_call_target("DynamicSelectAndScatterSamePadding");
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                            {operand, source, init_value});
    }
    return SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides, padding_values,
        source, init_value, scatter);
  });
}

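// Illustrative sketch (assumed example, not from the original source): the
// classic use of SelectAndScatter is routing gradients back to the arg-max
// positions of a max pool. `select` picks the window element to update (here
// a Ge comparison) and `scatter` combines routed values (here an Add);
// ge_f32 and add_f32 are assumed to be single-op XlaComputations, e.g. built
// with the CreateScalarGeComputation / CreateScalarAddComputation helpers.
//
//   XlaOp grads = SelectAndScatter(
//       activations, /*select=*/ge_f32, /*window_dimensions=*/{1, 2, 2, 1},
//       /*window_strides=*/{1, 2, 2, 1}, Padding::kValid,
//       /*source=*/incoming_grads, /*init_value=*/zero, /*scatter=*/add_f32);
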
StatusOr<HloInstructionProto> XlaBuilder::SelectAndScatterInternal(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  HloInstructionProto instr;

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
  TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
  TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
  TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
                      select.GetProgramShape());
  TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
                      scatter.GetProgramShape());
  TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferSelectAndScatterShape(
                          *operand_shape, select_shape, instr.window(),
                          *source_shape, *init_shape, scatter_shape));
  *instr.mutable_shape() = shape.ToProto();

  AddCalledComputation(select, &instr);
  AddCalledComputation(scatter, &instr);
  return instr;
}

XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(HloInstructionProto instr,
                        SelectAndScatterInternal(
                            operand, select, window_dimensions, window_strides,
                            padding, source, init_value, scatter));

    return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
                          {operand, source, init_value});
  });
}

XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferReducePrecisionShape(
                            *operand_shape, exponent_bits, mantissa_bits));
    return ReducePrecisionInternal(shape, operand, exponent_bits,
                                   mantissa_bits);
  });
}

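// Illustrative sketch (assumed example, not from the original source):
// ReducePrecision keeps the operand in its original float type but rounds it
// as if it had the given exponent/mantissa widths. With 8 exponent and 7
// mantissa bits this emulates bfloat16 rounding on an F32 value:
//
//   XlaOp as_bf16 = ReducePrecision(x, /*exponent_bits=*/8,
//                                   /*mantissa_bits=*/7);
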
StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
                                                    XlaOp operand,
                                                    const int exponent_bits,
                                                    const int mantissa_bits) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_exponent_bits(exponent_bits);
  instr.set_mantissa_bits(mantissa_bits);
  return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
                        {operand});
}

void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Send HLO takes two operands: a data operand and a token. Generate the
    // token to pass into the send.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    return SendWithToken(operand, token, handle);
  });
}

XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Send must use a device-to-device channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Recv HLO takes a single token operand. Generate the token to pass into
    // the Recv and RecvDone instructions.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    XlaOp recv = RecvWithToken(token, shape, handle);

    // The RecvDone instruction produces a tuple of the data and a token.
    // Return the XlaOp containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto recv_data;
    *recv_data.mutable_shape() = shape.ToProto();
    recv_data.set_tuple_index(0);
    return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
                          {recv});
  });
}

XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Recv must use a device-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}

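// Illustrative sketch (assumed example, not from the original source): the
// token-based forms are meant to be chained so that side-effecting transfers
// stay ordered. On the sending side the token returned by SendWithToken
// feeds the next transfer; the receiving side mirrors this with
// RecvWithToken, whose result is the {data, token} tuple built above:
//
//   XlaOp tok = CreateToken(&b);
//   XlaOp tok1 = SendWithToken(payload, tok, channel);
//   XlaOp tok2 = SendWithToken(payload2, tok1, channel2);
//
//   XlaOp recvd = RecvWithToken(tok, shape, channel);
//   XlaOp data = GetTupleElement(recvd, 0);     // received buffer
//   XlaOp tok_out = GetTupleElement(recvd, 1);  // token for the next op
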
XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
                             const Shape& shape_with_layout,
                             const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Shape passed to SendToHost must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "SendToHost shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    // TODO(b/111544877): Support tuple shapes.
    if (!operand_shape->IsArray()) {
      return InvalidArgument("SendToHost only supports array shapes, shape: %s",
                             ShapeUtil::HumanString(*operand_shape));
    }

    if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
      return InvalidArgument("SendToHost must use a device-to-host channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape_with_layout,
                                   ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    send_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    send_done_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp send_done,
                        AddInstruction(std::move(send_done_instr),
                                       HloOpcode::kSendDone, {send}));
    return send_done;
  });
}

XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
                               const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Shape passed to RecvFromHost must have a layout");
    }

    // TODO(b/111544877): Support tuple shapes.
    if (!shape.IsArray()) {
      return InvalidArgument(
          "RecvFromHost only supports array shapes, shape: %s",
          ShapeUtil::HumanString(shape));
    }

    if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
      return InvalidArgument("RecvFromHost must use a host-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    recv_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    recv_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}

XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
                                         *operand_shape, dimension));
    // Calling GetDimensionSize on a static dimension returns a constant
    // instruction.
    if (!operand_shape->is_dynamic_dimension(dimension)) {
      return ConstantR0<int32_t>(this, operand_shape->dimensions(dimension));
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
                          {operand});
  });
}

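// Illustrative sketch (assumed example, not from the original source): a
// static dimension folds to a constant, while only a dynamic dimension
// produces a kGetDimensionSize instruction.
//
//   // x : f32[16, <=32] where dimension 1 is dynamic.
//   XlaOp d0 = GetDimensionSize(x, 0);  // folds to ConstantR0<int32_t>(16)
//   XlaOp d1 = GetDimensionSize(x, 1);  // emits kGetDimensionSize
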
XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    Shape shape = *operand_shape;
    shape.set_dynamic_dimension(dimension, false);
    // Setting an op's dynamic dimension to its static size removes the
    // dynamic dimension.
    XlaOp static_size =
        ConstantR0<int32_t>(this, operand_shape->dimensions(dimension));
    return SetDimensionSizeInternal(shape, operand, static_size, dimension);
  });
}

XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val,
                                   int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));

    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferSetDimensionSizeShape(
                            *operand_shape, *val_shape, dimension));
    return SetDimensionSizeInternal(shape, operand, val, dimension);
  });
}

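// Illustrative sketch (assumed example, not from the original source):
// SetDimensionSize marks a dimension as dynamic with a runtime size, and is
// the inverse of RemoveDynamicDimension above:
//
//   // x : f32[16]; n : s32[] with n <= 16 at runtime.
//   XlaOp dyn = SetDimensionSize(x, n, /*dimension=*/0);         // f32[<=16]
//   XlaOp back = RemoveDynamicDimension(dyn, /*dimension=*/0);   // f32[16]
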
StatusOr<XlaOp> XlaBuilder::SetDimensionSizeInternal(const Shape& shape,
                                                     XlaOp operand, XlaOp val,
                                                     int64_t dimension) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
                      LookUpInstruction(val));
  // Setting a dynamic dimension to a constant equal to its size bound is a
  // no-op, so return the operand unchanged.
  if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
          HloOpcode::kConstant &&
      shape.is_dynamic_dimension(dimension)) {
    TF_ASSIGN_OR_RETURN(auto constant_size,
                        Literal::CreateFromProto(val_proto->literal(), true));
    if (constant_size.Get<int32_t>({}) == shape.dimensions(dimension)) {
      return operand;
    }
  }

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.add_dimensions(dimension);
  return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
                        {operand, val});
}

StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
  TF_RETURN_IF_ERROR(first_error_);

  // Verify that the handle is valid.
  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());

  bool is_constant = true;
  absl::flat_hash_set<int64_t> visited;
  IsConstantVisitor(operand.handle(), /*depth=*/0, &visited, &is_constant);
  return is_constant;
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);
    std::string op_string =
        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
    return InvalidArgument(
        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
        "  op requested for constant subgraph: %s\n\n"
        "This is an internal error that typically happens when the XLA user "
        "(e.g. TensorFlow) is attempting to determine a value that must be a "
        "compile-time constant (e.g. an array dimension) but it is not capable "
        "of being evaluated at XLA compile time.\n\n"
        "Please file a usability bug with the framework being used (e.g. "
        "TensorFlow).",
        op_string);
  }

  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Build constant subgraph for:\n" << OpToString(root_op);
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                    GetNextId());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() = root->shape();

  // We use std::set to keep the instruction ids in ascending order (which is
  // also a valid dependency order). The related ops will be added to the
  // subgraph in the same order.
  std::set<int64_t> related_ops;
  absl::flat_hash_map<int64_t, int64_t> substitutions;
  absl::flat_hash_set<int64_t> related_calls;  // Related computations.
  std::queue<int64_t> worklist;
  worklist.push(root->id());
  related_ops.insert(root->id());

  while (!worklist.empty()) {
    int64_t handle = worklist.front();
    worklist.pop();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(handle));

    auto default_behavior = [&related_ops, &worklist, &related_calls,
                             instr_proto]() {
      for (int64_t id : instr_proto->operand_ids()) {
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
      }
      for (int64_t called_id : instr_proto->called_computation_ids()) {
        related_calls.insert(called_id);
      }
    };

    if (instr_proto->opcode() ==
            HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_proto)) {
      int32_t constant_value = -1;
      HloInstructionProto const_instr;

      if (instr_proto->opcode() ==
          HloOpcodeString(HloOpcode::kGetDimensionSize)) {
        // At this point, BuildConstantSubGraph should never encounter a
        // GetDimensionSize with a dynamic dimension: the IsConstant check
        // would have failed at the beginning of this function.
        //
        // Replace GetDimensionSize with a Constant representing the static
        // bound of the shape.
        int64_t dimension = instr_proto->dimensions(0);
        int64_t operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));

        if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
              dynamic_dimension_is_minus_one)) {
          constant_value = static_cast<int32_t>(
              operand_proto->shape().dimensions(dimension));
        }
        Literal literal = LiteralUtil::CreateR0(constant_value);
        *const_instr.mutable_literal() = literal.ToProto();
        *const_instr.mutable_shape() = literal.shape().ToProto();
      } else {
        if (instr_proto->literal().shape().element_type() == TUPLE) {
          *const_instr.mutable_literal() =
              // The first literal of SetBound contains the bounds; the
              // second contains the dynamism indicators.
              instr_proto->literal().tuple_literals(0);
        } else {
          *const_instr.mutable_literal() = instr_proto->literal();
        }

        *const_instr.mutable_shape() = instr_proto->shape();
      }
      *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
      const_instr.set_id(handle);
      *const_instr.mutable_name() =
          GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
      *entry.add_instructions() =
          const_instr;  // Add to the result constant graph.

    } else if (instr_proto->opcode() ==
               HloOpcodeString(HloOpcode::kGetTupleElement)) {
      // Look through GTE(Tuple(..), i).
      TF_ASSIGN_OR_RETURN(
          const HloInstructionProto* maybe_tuple_instr,
          LookUpInstructionByHandle(instr_proto->operand_ids(0)));

      if (maybe_tuple_instr->opcode() == HloOpcodeString(HloOpcode::kTuple)) {
        int64_t id = maybe_tuple_instr->operand_ids(instr_proto->tuple_index());
        // Enqueue any dependencies of `id`.
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
        substitutions[handle] = id;

      } else {
        default_behavior();
      }

    } else {
      default_behavior();
    }
  }

  // Resolve any substitutions for the root id.
  int64_t root_id = root->id();
  auto it = substitutions.find(root_id);
  while (it != substitutions.end()) {
    root_id = it->second;
    it = substitutions.find(root_id);
  }
  entry.set_root_id(root_id);

  // Add related ops to the computation.
  for (int64_t id : related_ops) {
    if (substitutions.find(id) != substitutions.end()) {
      // Skip adding this instruction; we will replace references to it with
      // the substitution instruction's id.
      continue;
    }
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
                        LookUpInstructionByHandle(id));

    if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_src)) {
      continue;
    }
    HloInstructionProto* instr = entry.add_instructions();
    *instr = *instr_src;
    // Replace operands in case we have substitutions mapped.
    instr->clear_operand_ids();
    for (int64_t operand_id : instr_src->operand_ids()) {
      auto it = substitutions.find(operand_id);
      while (it != substitutions.end()) {
        operand_id = it->second;
        it = substitutions.find(operand_id);
      }
      instr->add_operand_ids(operand_id);
    }
    // Ensure that the instruction names are unique within the graph.
    const std::string& new_name =
        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
    instr->set_name(new_name);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& e : embedded_) {
    if (related_calls.find(e.second.id()) != related_calls.end()) {
      *module->add_computations() = e.second;
    }
  }
  *module->add_computations() = std::move(entry);
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Constant computation:\n" << module->DebugString();
  }
  return std::move(computation);
}

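// Illustrative sketch (assumed example, not from the original source) of the
// substitution step above: for a graph
//
//   t = tuple(a, b); g = get-tuple-element(t, 1); root = add(g, c)
//
// the worklist records g -> b in `substitutions`, so the emitted constant
// subgraph is simply add(b, c); neither t nor g is copied into it.
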
std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
    const std::string& computation_name) {
  auto sub_builder = std::make_unique<XlaBuilder>(computation_name);
  sub_builder->parent_builder_ = this;
  sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
  return sub_builder;
}

/* static */ ConvolutionDimensionNumbers
XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_kernel_output_feature_dimension(
      kConvKernelOutputDimension);
  dimension_numbers.set_kernel_input_feature_dimension(
      kConvKernelInputDimension);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 2);
    dimension_numbers.add_kernel_spatial_dimensions(i + 2);
    dimension_numbers.add_output_spatial_dimensions(i + 2);
  }
  return dimension_numbers;
}

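// Illustrative note (an assumption, not stated in this file): with
// num_spatial_dims == 2 and the batch/feature constants placed at dimensions
// 0 and 1, these defaults correspond to NCHW activations and OIHW kernels,
// with the spatial dimensions of input, kernel and output all at {2, 3}.
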
/* static */ Status XlaBuilder::Validate(
    const ConvolutionDimensionNumbers& dnum) {
  if (dnum.input_spatial_dimensions_size() < 2) {
    return FailedPrecondition("input spatial dimension < 2: %d",
                              dnum.input_spatial_dimensions_size());
  }
  if (dnum.kernel_spatial_dimensions_size() < 2) {
    return FailedPrecondition("kernel spatial dimension < 2: %d",
                              dnum.kernel_spatial_dimensions_size());
  }
  if (dnum.output_spatial_dimensions_size() < 2) {
    return FailedPrecondition("output spatial dimension < 2: %d",
                              dnum.output_spatial_dimensions_size());
  }

  if (std::set<int64_t>(
          {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
           dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%d, %d, %d, "
        "%d)",
        dnum.input_batch_dimension(), dnum.input_feature_dimension(),
        dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
  }
  if (std::set<int64_t>({dnum.kernel_output_feature_dimension(),
                         dnum.kernel_input_feature_dimension(),
                         dnum.kernel_spatial_dimensions(0),
                         dnum.kernel_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%d, %d, %d, "
        "%d)",
        dnum.kernel_output_feature_dimension(),
        dnum.kernel_input_feature_dimension(),
        dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
  }
  if (std::set<int64_t>({dnum.output_batch_dimension(),
                         dnum.output_feature_dimension(),
                         dnum.output_spatial_dimensions(0),
                         dnum.output_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%d, %d, %d, "
        "%d)",
        dnum.output_batch_dimension(), dnum.output_feature_dimension(),
        dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
  }
  return OkStatus();
}

StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
                                           HloOpcode opcode,
                                           absl::Span<const XlaOp> operands) {
  TF_RETURN_IF_ERROR(first_error_);

  const int64_t handle = GetNextId();
  instr.set_id(handle);
  instr.set_opcode(HloOpcodeString(opcode));
  if (instr.name().empty()) {
    instr.set_name(instr.opcode());
  }
  for (const auto& operand : operands) {
    if (operand.builder_ == nullptr) {
      return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
    }
    if (operand.builder_ != this) {
      return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
                             operand.builder_->name(), this->name());
    }
    instr.add_operand_ids(operand.handle());
  }

  if (one_shot_metadata_.has_value()) {
    *instr.mutable_metadata() = one_shot_metadata_.value();
    one_shot_metadata_.reset();
  } else {
    *instr.mutable_metadata() = metadata_;
  }
  if (sharding_) {
    *instr.mutable_sharding() = *sharding_;
  }
  *instr.mutable_frontend_attributes() = frontend_attributes_;

  handle_to_index_[handle] = instructions_.size();
  instructions_.push_back(std::move(instr));
  instruction_shapes_.push_back(
      std::make_unique<Shape>(instructions_.back().shape()));

  XlaOp op(handle, this);
  return op;
}

StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                           absl::Span<const XlaOp> operands) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), opcode, operands);
}

void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
                                      HloInstructionProto* instr) {
  absl::flat_hash_map<int64_t, int64_t> remapped_ids;
  std::vector<HloComputationProto> imported_computations;
  imported_computations.reserve(computation.proto().computations_size());
  // First, import the computations, remapping the IDs and capturing the
  // old->new mappings in remapped_ids.
  for (const HloComputationProto& e : computation.proto().computations()) {
    HloComputationProto new_computation(e);
    int64_t computation_id = GetNextId();
    remapped_ids[new_computation.id()] = computation_id;
    SetProtoIdAndName(&new_computation,
                      GetBaseName(new_computation.name(), kNameSeparator),
                      kNameSeparator, computation_id);
    for (auto& instruction : *new_computation.mutable_instructions()) {
      int64_t instruction_id = GetNextId();
      remapped_ids[instruction.id()] = instruction_id;
      SetProtoIdAndName(&instruction,
                        GetBaseName(instruction.name(), kNameSeparator),
                        kNameSeparator, instruction_id);
    }
    new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));

    imported_computations.push_back(std::move(new_computation));
  }
  // Once we have imported all the computations and captured all the ID
  // mappings, we go back and fix up the IDs in the imported computations.
  instr->add_called_computation_ids(
      remapped_ids.at(computation.proto().entry_computation_id()));
  for (auto& imported_computation : imported_computations) {
    for (auto& instruction : *imported_computation.mutable_instructions()) {
      for (auto& operand_id : *instruction.mutable_operand_ids()) {
        operand_id = remapped_ids.at(operand_id);
      }
      for (auto& control_predecessor_id :
           *instruction.mutable_control_predecessor_ids()) {
        control_predecessor_id = remapped_ids.at(control_predecessor_id);
      }
      for (auto& called_computation_id :
           *instruction.mutable_called_computation_ids()) {
        called_computation_id = remapped_ids.at(called_computation_id);
      }
    }

    int64_t computation_id = imported_computation.id();
    for (int64_t i = 0; i < imported_computation.instructions_size(); ++i) {
      ImportedInstruction imported_instruction;
      imported_instruction.computation_id = computation_id;
      imported_instruction.instruction_index = i;
      handle_to_imported_index_.insert(
          {imported_computation.instructions(i).id(), imported_instruction});
    }
    embedded_.insert({computation_id, std::move(imported_computation)});
  }
}

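// Illustrative sketch (assumed example, not from the original source) of the
// two-pass remapping above: if the imported computation contains
//
//   %c0 {id=1}: %p {id=2}, %neg {id=3, operands=[2]}
//
// and this builder's next IDs are 40, 41, 42, the first pass records
// {1->40, 2->41, 3->42} and renames the protos; the second pass rewrites
// %neg's operand list from [2] to [41], keeping the call graph consistent.
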
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
    const XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  return LookUpInstructionInternal<const HloInstructionProto*>(op);
}

StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
    int64_t handle) const {
  return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
}

StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
    const XlaOp op) {
  TF_RETURN_IF_ERROR(first_error_);
  return LookUpInstructionInternal<HloInstructionProto*>(op);
}

StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
    int64_t handle) {
  return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
}

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const std::string& name) {
  std::vector<bool> empty_bools;
  return Parameter(builder, parameter_number, shape, name, empty_bools);
}

XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const std::string& name,
                const std::vector<bool>& replicated_at_leaf_buffers) {
  return builder->Parameter(parameter_number, shape, name,
                            replicated_at_leaf_buffers);
}

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
  return builder->ConstantLiteral(literal);
}

XlaOp Broadcast(const XlaOp operand,
                absl::Span<const int64_t> broadcast_sizes) {
  return operand.builder()->Broadcast(operand, broadcast_sizes);
}

XlaOp BroadcastInDim(const XlaOp operand,
                     const absl::Span<const int64_t> out_dim_size,
                     const absl::Span<const int64_t> broadcast_dimensions) {
  return operand.builder()->BroadcastInDim(operand, out_dim_size,
                                           broadcast_dimensions);
}

XlaOp Copy(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
}

XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
          const PaddingConfig& padding_config) {
  return operand.builder()->Pad(operand, padding_value, padding_config);
}

XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
               int64_t pad_lo, int64_t pad_hi) {
  return operand.builder()->PadInDim(operand, padding_value, dimno, pad_lo,
                                     pad_hi);
}

XlaOp Reshape(const XlaOp operand, absl::Span<const int64_t> dimensions,
              absl::Span<const int64_t> new_sizes) {
  return operand.builder()->Reshape(operand, dimensions, new_sizes);
}

XlaOp Reshape(const XlaOp operand, absl::Span<const int64_t> new_sizes) {
  return operand.builder()->Reshape(operand, new_sizes);
}

XlaOp Reshape(const Shape& shape, XlaOp operand) {
  return operand.builder()->Reshape(shape, operand);
}

XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64_t> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic) {
  return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
                                           dims_are_dynamic);
}

XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64_t> new_sizes,
                                   int64_t inferred_dimension) {
  return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
}

XlaOp Collapse(const XlaOp operand, absl::Span<const int64_t> dimensions) {
  return operand.builder()->Collapse(operand, dimensions);
}

XlaOp Slice(const XlaOp operand, absl::Span<const int64_t> start_indices,
            absl::Span<const int64_t> limit_indices,
            absl::Span<const int64_t> strides) {
  return operand.builder()->Slice(operand, start_indices, limit_indices,
                                  strides);
}

XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index,
                 int64_t stride, int64_t dimno) {
  return operand.builder()->SliceInDim(operand, start_index, limit_index,
                                       stride, dimno);
}

XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64_t> slice_sizes) {
  return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}

XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
                         absl::Span<const XlaOp> start_indices) {
  return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}

XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64_t dimension) {
  return builder->ConcatInDim(operands, dimension);
}

XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
  return pred.builder()->Select(pred, on_true, on_false);
}

XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
  return builder->Tuple(elements);
}

XlaOp GetTupleElement(const XlaOp tuple_data, int64_t index) {
  return tuple_data.builder()->GetTupleElement(tuple_data, index);
}

XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}

static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
                               absl::Span<const int64_t> broadcast_dimensions,
                               ComparisonDirection comparison_direction) {
  auto b = lhs.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
    auto operand_element_type = operand_shape.element_type();
    auto compare_type =
        primitive_util::IsFloatingPointType(operand_element_type)
            ? Comparison::Type::kFloatTotalOrder
            : Comparison::DefaultComparisonType(operand_element_type);
    return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
                   compare_type);
  });
}

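// Illustrative note (assumed, not from the original source): for floating
// point inputs the *TotalOrder variants below use a total order in which
// values sort as
//   -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN,
// so NaNs compare consistently instead of making every comparison false:
//
//   XlaOp lt = LtTotalOrder(x, y);  // true for x = -NaN, y = -Inf
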
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kEq);
}

XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
}

XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kNe);
}

XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
}

XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kGe);
}

XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
}

XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kGt);
}

XlaOp Le(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
}

XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kLe);
}

XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
}

XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kLt);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions,
              ComparisonDirection direction) {
  return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                 broadcast_dimensions, direction);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type) {
  return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                 broadcast_dimensions, direction, compare_type);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
  return Compare(lhs, rhs, {}, direction);
}

XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
          const PrecisionConfig* precision_config,
          std::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
}

XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config,
                 std::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
                                   precision_config, preferred_element_type);
}

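// Illustrative sketch (assumed example, not from the original source): a
// batched matmul of f32[B,M,K] by f32[B,K,N] expressed via DotGeneral, with
// dimension 0 as the batch dimension on both sides and K contracted:
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   XlaOp out = DotGeneral(lhs, rhs, dnums);  // f32[B,M,N]
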
Conv(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,Padding padding,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4264 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
4265            absl::Span<const int64_t> window_strides, Padding padding,
4266            int64_t feature_group_count, int64_t batch_group_count,
4267            const PrecisionConfig* precision_config,
4268            std::optional<PrimitiveType> preferred_element_type) {
4269   return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
4270                              feature_group_count, batch_group_count,
4271                              precision_config, preferred_element_type);
4272 }
4273 
ConvWithGeneralPadding(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4274 XlaOp ConvWithGeneralPadding(
4275     const XlaOp lhs, const XlaOp rhs, absl::Span<const int64_t> window_strides,
4276     absl::Span<const std::pair<int64_t, int64_t>> padding,
4277     int64_t feature_group_count, int64_t batch_group_count,
4278     const PrecisionConfig* precision_config,
4279     std::optional<PrimitiveType> preferred_element_type) {
4280   return lhs.builder()->ConvWithGeneralPadding(
4281       lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
4282       precision_config, preferred_element_type);
4283 }
4284 
ConvWithGeneralDimensions(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,Padding padding,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4285 XlaOp ConvWithGeneralDimensions(
4286     const XlaOp lhs, const XlaOp rhs, absl::Span<const int64_t> window_strides,
4287     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
4288     int64_t feature_group_count, int64_t batch_group_count,
4289     const PrecisionConfig* precision_config,
4290     std::optional<PrimitiveType> preferred_element_type) {
4291   return lhs.builder()->ConvWithGeneralDimensions(
4292       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4293       batch_group_count, precision_config, preferred_element_type);
4294 }
4295 
ConvGeneral(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type)4296 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
4297                   absl::Span<const int64_t> window_strides,
4298                   absl::Span<const std::pair<int64_t, int64_t>> padding,
4299                   const ConvolutionDimensionNumbers& dimension_numbers,
4300                   int64_t feature_group_count, int64_t batch_group_count,
4301                   const PrecisionConfig* precision_config,
4302                   std::optional<PrimitiveType> preferred_element_type) {
4303   return lhs.builder()->ConvGeneral(
4304       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4305       batch_group_count, precision_config, preferred_element_type);
4306 }
4307 
ConvGeneralDilated(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,absl::Span<const int64_t> lhs_dilation,absl::Span<const int64_t> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config,std::optional<PrimitiveType> preferred_element_type,std::optional<std::vector<bool>> window_reversal)4308 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
4309                          absl::Span<const int64_t> window_strides,
4310                          absl::Span<const std::pair<int64_t, int64_t>> padding,
4311                          absl::Span<const int64_t> lhs_dilation,
4312                          absl::Span<const int64_t> rhs_dilation,
4313                          const ConvolutionDimensionNumbers& dimension_numbers,
4314                          int64_t feature_group_count, int64_t batch_group_count,
4315                          const PrecisionConfig* precision_config,
4316                          std::optional<PrimitiveType> preferred_element_type,
4317                          std::optional<std::vector<bool>> window_reversal) {
4318   return lhs.builder()->ConvGeneralDilated(
4319       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
4320       dimension_numbers, feature_group_count, batch_group_count,
4321       precision_config, preferred_element_type, window_reversal);
4322 }
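
// Note on ConvGeneralDilated above: lhs_dilation of d inserts (d - 1) zeros
// between input elements (the "transposed convolution" trick), while
// rhs_dilation dilates the kernel (atrous/dilated convolution); strides and
// the per-dimension (low, high) padding pairs apply to the dilated shapes.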

XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    std::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvInputGrad(
      input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    std::optional<PrimitiveType> preferred_element_type) {
  return activations.builder()->DynamicConvKernelGrad(
      activations, gradients, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64_t> window_strides,
                         absl::Span<const std::pair<int64_t, int64_t>> padding,
                         absl::Span<const int64_t> lhs_dilation,
                         absl::Span<const int64_t> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64_t feature_group_count, int64_t batch_group_count,
                         const PrecisionConfig* precision_config,
                         PaddingType padding_type,
                         std::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvForward(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp Fft(const XlaOp operand, FftType fft_type,
          absl::Span<const int64_t> fft_length) {
  return operand.builder()->Fft(operand, fft_type, fft_length);
}

XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
    TriangularSolveOptions options;
    options.set_left_side(left_side);
    options.set_lower(lower);
    options.set_unit_diagonal(unit_diagonal);
    options.set_transpose_a(transpose_a);
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
                                         *a_shape, *b_shape, options));
    return builder->TriangularSolveInternal(shape, a, b, std::move(options));
  });
}
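
// TriangularSolve above and Cholesky below follow the builder's standard
// pattern: run shape inference eagerly, then hand the inferred shape to the
// corresponding *Internal method. ReportErrorOrReturn records any failure on
// the builder itself, so it surfaces when the computation is finalized rather
// than at each call site.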

XlaOp Cholesky(XlaOp a, bool lower) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferCholeskyShape(*a_shape));
    return builder->CholeskyInternal(shape, a, lower);
  });
}

XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const std::string& config) {
  return builder->Infeed(shape, config);
}

void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
             const std::string& outfeed_config) {
  return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
}

XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands) {
  return builder->Call(computation, operands);
}

XlaOp CustomCall(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             /*operand_shapes_with_layout=*/std::nullopt,
                             has_side_effect, output_operand_aliasing, literal,
                             /*window=*/std::nullopt, /*dnums=*/std::nullopt,
                             schedule, api_version);
}
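
// The CustomCall overload above imposes no operand layout constraints
// (operand_shapes_with_layout is std::nullopt) and attaches no convolution
// window or dimension numbers; the variants below cover those cases, plus
// custom calls that carry a called computation.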

XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(
      call_target_name, operands, computation, shape, opaque,
      /*operand_shapes_with_layout=*/std::nullopt, has_side_effect,
      output_operand_aliasing, literal, schedule, api_version);
}

XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout,
    const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(
      call_target_name, operands, shape, opaque, operand_shapes_with_layout,
      has_side_effect, output_operand_aliasing, literal,
      /*window=*/std::nullopt, /*dnums=*/std::nullopt, schedule, api_version);
}

XlaOp CustomCallWithConvDnums(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout,
    const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  std::optional<absl::Span<const Shape>> maybe_operand_shapes;
  if (!operand_shapes_with_layout.empty()) {
    maybe_operand_shapes = operand_shapes_with_layout;
  }
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             maybe_operand_shapes, has_side_effect,
                             output_operand_aliasing, literal, window, dnums,
                             schedule, api_version);
}
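
// Note that CustomCallWithConvDnums above treats an empty
// operand_shapes_with_layout span as "no layout constraints" (std::nullopt)
// rather than as a constraint list of length zero.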

XlaOp OptimizationBarrier(XlaOp operand) {
  return operand.builder()->OptimizationBarrier(operand);
}

XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Conj(const XlaOp operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}

XlaOp Add(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Div(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Max(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Min(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp And(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Or(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Not(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
}

XlaOp PopulationCount(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
}

XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
                absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
                           absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
                        absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
             const XlaComputation& computation,
             absl::Span<const int64_t> dimensions_to_reduce) {
  return operand.builder()->Reduce(operand, init_value, computation,
                                   dimensions_to_reduce);
}

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64_t> dimensions_to_reduce) {
  return builder->Reduce(operands, init_values, computation,
                         dimensions_to_reduce);
}
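
// Illustrative sketch (assuming the CreateScalarAddComputation helper from
// xla/client/lib/arithmetic.h; variable names are hypothetical): summing each
// row of a 2D F32 array with the single-operand Reduce overload above:
//   XlaComputation add = CreateScalarAddComputation(F32, builder);
//   XlaOp row_sums = Reduce(operand, ConstantR0<float>(builder, 0.0f), add,
//                           /*dimensions_to_reduce=*/{1});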

XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
                const XlaComputation& computation) {
  return operand.builder()->ReduceAll(operand, init_value, computation);
}

XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64_t> window_dimensions,
                   absl::Span<const int64_t> window_strides, Padding padding) {
  return operand.builder()->ReduceWindow(operand, init_value, computation,
                                         window_dimensions, window_strides,
                                         padding);
}

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64_t> window_dimensions,
                   absl::Span<const int64_t> window_strides, Padding padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindow(operands, init_values, computation,
                                             window_dimensions, window_strides,
                                             padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp operand, const XlaOp init_value,
    const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding) {
  return operand.builder()->ReduceWindowWithGeneralPadding(
      absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
      window_dimensions, window_strides, base_dilations, window_dilations,
      padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindowWithGeneralPadding(
      operands, init_values, computation, window_dimensions, window_strides,
      base_dilations, window_dilations, padding);
}

XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
                int64_t shard_count,
                absl::Span<const ReplicaGroup> replica_groups,
                const std::optional<ChannelHandle>& channel_id,
                const std::optional<Layout>& layout,
                const std::optional<bool> use_global_device_ids) {
  return operand.builder()->AllGather(operand, all_gather_dimension,
                                      shard_count, replica_groups, channel_id,
                                      layout, use_global_device_ids);
}

XlaOp CrossReplicaSum(const XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups) {
  return operand.builder()->CrossReplicaSum(operand, replica_groups);
}

XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups,
                const std::optional<ChannelHandle>& channel_id,
                const std::optional<Shape>& shape_with_layout,
                const std::optional<bool> use_global_device_ids) {
  return operand.builder()->AllReduce(operand, computation, replica_groups,
                                      channel_id, shape_with_layout,
                                      use_global_device_ids);
}

XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
                    int64_t scatter_dimension, int64_t shard_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const std::optional<ChannelHandle>& channel_id,
                    const std::optional<Layout>& layout,
                    const std::optional<bool> use_global_device_ids) {
  return operand.builder()->ReduceScatter(
      operand, computation, scatter_dimension, shard_count, replica_groups,
      channel_id, layout, use_global_device_ids);
}

XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
               int64_t concat_dimension, int64_t split_count,
               absl::Span<const ReplicaGroup> replica_groups,
               const std::optional<Layout>& layout) {
  return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
                                     split_count, replica_groups, layout);
}

XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const std::optional<Layout>& layout) {
  CHECK(!operands.empty());
  return operands[0].builder()->AllToAllTuple(operands, replica_groups, layout);
}

XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension,
                    int64_t concat_dimension, int64_t split_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const std::optional<Layout>& layout) {
  return operand.builder()->AllToAllTuple(operand, split_dimension,
                                          concat_dimension, split_count,
                                          replica_groups, layout);
}

XlaOp CollectivePermute(
    const XlaOp operand,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
  return operand.builder()->CollectivePermute(operand, source_target_pairs);
}

XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }

XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64_t> window_dimensions,
                       absl::Span<const int64_t> window_strides,
                       Padding padding, const XlaOp source,
                       const XlaOp init_value, const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
                                             window_strides, padding, source,
                                             init_value, scatter);
}

XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp operand, const XlaComputation& select,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding, const XlaOp source,
    const XlaOp init_value, const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      init_value, scatter);
}

XlaOp Abs(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Exp(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
}
XlaOp Expm1(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
}
XlaOp Floor(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
}
XlaOp Ceil(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
}
XlaOp Round(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
XlaOp RoundNearestEven(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestEven, operand);
}
XlaOp Log(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
}
XlaOp Log1p(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
}
XlaOp Logistic(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand);
}
XlaOp Sign(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
}
XlaOp Clz(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
}
XlaOp Cos(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
}
XlaOp Sin(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
}
XlaOp Tanh(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
}
XlaOp Real(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
}
XlaOp Imag(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
}
XlaOp Sqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
}
XlaOp Cbrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand);
}
XlaOp Rsqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
}

XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp IsFinite(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}

XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->ConvertElementType(operand, new_element_type);
}

XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->BitcastConvertType(operand, new_element_type);
}

XlaOp Neg(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
}

XlaOp Transpose(const XlaOp operand, absl::Span<const int64_t> permutation) {
  return operand.builder()->Transpose(operand, permutation);
}

XlaOp Rev(const XlaOp operand, absl::Span<const int64_t> dimensions) {
  return operand.builder()->Rev(operand, dimensions);
}

XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension, bool is_stable) {
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}
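
// Illustrative sketch (assuming the CreateScalarLtComputation helper from
// xla/client/lib/comparators.h; names hypothetical): a stable ascending sort
// of a 1D F32 array along its only dimension:
//   XlaComputation less = CreateScalarLtComputation({F32}, builder);
//   XlaOp sorted = Sort({input}, less, /*dimension=*/0, /*is_stable=*/true);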

XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
  return min.builder()->Clamp(min, operand, max);
}

XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation,
          absl::Span<const int64_t> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions, static_operands);
}

XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  return initial_state.builder()->RngBitGenerator(algorithm, initial_state,
                                                  shape);
}

XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  return init.builder()->While(condition, body, init);
}

XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}

XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
                      const int mantissa_bits) {
  return operand.builder()->ReducePrecision(operand, exponent_bits,
                                            mantissa_bits);
}

XlaOp Gather(const XlaOp input, const XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
  return input.builder()->Gather(input, start_indices, dimension_numbers,
                                 slice_sizes, indices_are_sorted);
}

XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
              const XlaOp updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return input.builder()->Scatter(input, scatter_indices, updates,
                                  update_computation, dimension_numbers,
                                  indices_are_sorted, unique_indices);
}

XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
              absl::Span<const XlaOp> updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return scatter_indices.builder()->Scatter(
      inputs, scatter_indices, updates, update_computation, dimension_numbers,
      indices_are_sorted, unique_indices);
}

void Send(const XlaOp operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const std::string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const std::string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}

XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64_t feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64_t feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64_t feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
                                          grad_output, epsilon, feature_index);
}

XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}

XlaOp GetDimensionSize(const XlaOp operand, int64_t dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val,
                       int64_t dimension) {
  return operand.builder()->SetDimensionSize(operand, val, dimension);
}

XlaOp RemoveDynamicDimension(const XlaOp operand, int64_t dimension) {
  return operand.builder()->RemoveDynamicDimension(operand, dimension);
}

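// Builds the sharding used for the shard-shaped side of an SPMD full<->shard
// conversion: fully MANUAL when single_dim is negative or `original` is not
// tiled; otherwise the original OTHER tiling, with the partitions of
// single_dim moved to a trailing tile dimension that is marked MANUAL.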
OpSharding GetManualSharding(const OpSharding& original, int64_t single_dim) {
  OpSharding manual;
  if (single_dim < 0 || original.type() != OpSharding::OTHER) {
    manual.set_type(OpSharding::MANUAL);
    return manual;
  }
  manual.set_type(OpSharding::OTHER);
  std::vector<int64_t> new_tile_shape(
      original.tile_assignment_dimensions().begin(),
      original.tile_assignment_dimensions().end());
  new_tile_shape.push_back(new_tile_shape[single_dim]);
  new_tile_shape[single_dim] = 1;
  Array<int64_t> new_tile(new_tile_shape);
  new_tile.Each([&](absl::Span<const int64_t> indices, int64_t* v) {
    int64_t src_index = 0;
    for (int64_t i = 0; i < indices.size() - 1; ++i) {
      if (i > 0) {
        src_index *= new_tile_shape[i];
      }
      int64_t index = indices[i];
      if (i == single_dim) {
        index = indices.back();
      }
      src_index += index;
    }
    *v = original.tile_assignment_devices(src_index);
  });
  for (int64_t dim : new_tile_shape) {
    manual.add_tile_assignment_dimensions(dim);
  }
  for (int64_t device : new_tile) {
    manual.add_tile_assignment_devices(device);
  }
  if (original.replicate_on_last_tile_dim()) {
    manual.add_last_tile_dims(OpSharding::REPLICATED);
  }
  for (int64_t type : original.last_tile_dims()) {
    manual.add_last_tile_dims(static_cast<OpSharding::Type>(type));
  }
  manual.add_last_tile_dims(OpSharding::MANUAL);
  return manual;
}

StatusOr<XlaOp> ConvertSpmdFullToShardShape(
    XlaBuilder* builder, XlaOp input, int single_dim,
    const OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims) {
  TF_ASSIGN_OR_RETURN(const Shape input_shape, builder->GetShape(input));

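  // Derive the per-shard shape: every tiled dimension (or just single_dim,
  // when it is non-negative) is divided by its partition count, rounding up.
  // E.g. a dimension of size 10 split 4 ways yields shards of size 3.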
  Shape output_shape = input_shape;
  const int64_t rank = output_shape.rank();
  if (manual_sharding.type() == OpSharding::OTHER) {
    for (int64_t i = 0; i < rank; ++i) {
      if (single_dim >= 0 && i != single_dim) {
        continue;
      }
      const int64_t partitions_i =
          manual_sharding.tile_assignment_dimensions(i);
      if (partitions_i == 1) continue;
      const int64_t dim_size =
          CeilOfRatio(output_shape.dimensions(i), partitions_i);
      output_shape.set_dimensions(i, dim_size);
    }
  }

  XlaOp input_annotation;
  {
    // Annotate the full-shape input with the sharding.
    XlaScopedShardingAssignment assign_sharding(builder, manual_sharding);
    input_annotation = CustomCall(
        builder, /*call_target_name=*/"Sharding", {input}, input_shape,
        /*opaque=*/
        sharding_op_util::EncodeAttributes(unspecified_dims));
  }

  {
    // Annotate the shard-shape output with manual sharding, so that the
    // partitioner will leave it as is.
    OpSharding manual = GetManualSharding(manual_sharding, single_dim);
    XlaScopedShardingAssignment assign_sharding(builder, manual);
    return CustomCall(builder,
                      /*call_target_name=*/"SPMDFullToShardShape",
                      {input_annotation}, output_shape,
                      /*opaque=*/
                      sharding_op_util::EncodeAttributes(unspecified_dims));
  }
}

StatusOr<XlaOp> ConvertSpmdShardToFullShape(
    XlaBuilder* builder, XlaOp input, const Shape& output_shape, int single_dim,
    const OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims) {
  TF_ASSIGN_OR_RETURN(const Shape input_shape, builder->GetShape(input));

  XlaOp input_annotation;
  {
    // Annotate the shard-shape input with manual sharding, so that the
    // partitioner will leave it as is.
    OpSharding manual = GetManualSharding(manual_sharding, single_dim);
    XlaScopedShardingAssignment assign_sharding(builder, manual);
    input_annotation = CustomCall(
        builder, /*call_target_name=*/"Sharding", {input}, input_shape,
        sharding_op_util::EncodeAttributes(unspecified_dims));
  }

  {
    // Annotate the full-shape output with the sharding.
    XlaScopedShardingAssignment assign_sharding(builder, manual_sharding);
    return CustomCall(builder,
                      /*call_target_name=*/"SPMDShardToFullShape",
                      {input_annotation}, output_shape,
                      sharding_op_util::EncodeAttributes(unspecified_dims));
  }
}
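
// Note: "SPMDFullToShardShape" and "SPMDShardToFullShape" are the custom-call
// targets the SPMD partitioner recognizes as manual-sharding boundaries; the
// paired "Sharding" custom call carries the user-specified annotation.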

}  // namespace xla