/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/xla_builder.h"

#include <functional>
#include <numeric>
#include <queue>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

using absl::StrCat;

namespace {

static const char kNameSeparator = '.';

// Retrieves the base name from an instruction or computation's fully
// qualified name, using the separator as the boundary between the initial
// base-name part and the numeric identifier.
string GetBaseName(const string& name, char separator) {
  auto pos = name.rfind(separator);
  CHECK_NE(pos, string::npos) << name;
  return name.substr(0, pos);
}

// Generates a fully qualified computation/instruction name.
string GetFullName(const string& base_name, char separator, int64_t id) {
  const char separator_str[] = {separator, '\0'};
  return StrCat(base_name, separator_str, id);
}
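
// Example (illustrative): GetFullName("add", '.', 7) yields "add.7", and
// GetBaseName("add.7", '.') recovers "add".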

// Common function to standardize setting name and IDs on computation and
// instruction proto entities.
template <typename T>
void SetProtoIdAndName(T* entry, const string& base_name, char separator,
                       int64_t id) {
  entry->set_id(id);
  entry->set_name(GetFullName(base_name, separator, id));
}

bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
  HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
  if (opcode == HloOpcode::kCustomCall &&
      instr_proto->custom_call_target() == "SetBound") {
    return true;
  }
  return false;
}

}  // namespace

namespace internal {

XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
                                    absl::Span<const XlaOp> operands,
                                    absl::string_view fusion_kind,
                                    const XlaComputation& fused_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_fusion_kind(std::string(fusion_kind));
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(auto program_shape,
                        fused_computation.GetProgramShape());
    *instr.mutable_shape() = program_shape.result().ToProto();
    builder->AddCalledComputation(fused_computation, &instr);
    return builder->AddInstruction(std::move(instr), HloOpcode::kFusion,
                                   operands);
  });
}

XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
                                     const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
                                   {operand});
  });
}

HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
  return &op.builder()
              ->instructions_[op.builder()->handle_to_index_[op.handle_]];
}

}  // namespace internal

XlaOp operator-(XlaOp x) { return Neg(x); }
XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }

XlaOp operator~(XlaOp x) { return Not(x); }
XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }

XlaOp operator>>(XlaOp x, XlaOp y) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, builder->GetShapePtr(x));
    if (!ShapeUtil::ElementIsIntegral(*shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    if (ShapeUtil::ElementIsSigned(*shape)) {
      return ShiftRightArithmetic(x, y);
    } else {
      return ShiftRightLogical(x, y);
    }
  });
}
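
// Example (illustrative): these overloads let client code build expressions
// directly, e.g.
//   XlaOp z = (x + y) >> x;
// which emits kAdd followed by ShiftRightArithmetic for signed integral
// element types or ShiftRightLogical for unsigned ones; a non-integral
// operand to >> surfaces as an InvalidArgument error on the builder.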

StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  auto it = handle_to_index_.find(op.handle());
  if (it == handle_to_index_.end()) {
    return InvalidArgument("No XlaOp with handle %d", op.handle());
  }
  return instruction_shapes_.at(it->second).get();
}

StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
  TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
  return *shape;
}

StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    absl::Span<const XlaOp> operands) const {
  std::vector<Shape> operand_shapes;
  for (XlaOp operand : operands) {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    operand_shapes.push_back(*shape);
  }
  return operand_shapes;
}

std::string XlaBuilder::OpToString(XlaOp op) const {
  std::string s;
  ToStringHelper(&s, /*ident=*/0, op.handle());
  return s;
}

static std::string ShapeToString(const xla::ShapeProto& shape) {
  if (shape.tuple_shapes_size() > 1) {
    return absl::StrCat(
        "(",
        absl::StrJoin(shape.tuple_shapes(), ", ",
                      [&](std::string* s, const xla::ShapeProto& subshape) {
                        absl::StrAppend(s, ShapeToString(subshape));
                      }),
        ")");
  }
  return absl::StrCat("[", absl::StrJoin(shape.dimensions(), ", "), "]");
}
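
// Example (illustrative): a shape proto for f32[2,3] renders as "[2, 3]" and
// a tuple of f32[2] and f32[3] as "([2], [3])"; element types and layouts are
// deliberately omitted from this debug rendering.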

void XlaBuilder::ToStringHelper(std::string* out, int ident,
                                int64_t op_handle) const {
  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  absl::StrAppend(out, std::string(ident, ' '), instr.opcode(),
                  ", shape=", ShapeToString(instr.shape()));
  if (instr.has_metadata()) {
    absl::StrAppend(out, ", metadata={", instr.metadata().source_file(), ":",
                    instr.metadata().source_line(), "}");
  }
  if (instr.operand_ids_size()) {
    absl::StrAppend(out, "\n");
  }
  absl::StrAppend(out, absl::StrJoin(instr.operand_ids(), "\n",
                                     [&](std::string* s, int64_t subop) {
                                       ToStringHelper(s, ident + 2, subop);
                                     }));
}

XlaBuilder::XlaBuilder(const string& computation_name)
    : name_(computation_name) {}

XlaBuilder::~XlaBuilder() {}

XlaOp XlaBuilder::ReportError(const Status& error) {
  CHECK(!error.ok());
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
  return XlaOp(this);
}

XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
  if (!first_error_.ok()) {
    return XlaOp(this);
  }
  if (!op.ok()) {
    return ReportError(op.status());
  }
  return op.ValueOrDie();
}

XlaOp XlaBuilder::ReportErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  return ReportErrorOrReturn(op_creator());
}
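
// Example (illustrative): because op-building methods funnel through
// ReportErrorOrReturn, clients can chain ops without checking status at each
// step; assuming die_immediately_on_error_ is false, the first error is
// latched and resurfaces from Build():
//   XlaOp bad = b.ReportErrorOrReturn(
//       []() -> StatusOr<XlaOp> { return InvalidArgument("boom"); });
//   auto computation = b.Build();  // status() carries "boom".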

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64_t root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
                      LookUpInstructionByHandle(root_id));

  ProgramShape program_shape;

  *program_shape.mutable_result() = Shape(root_proto->shape());

  // Check that the parameter numbers are continuous from 0, and add parameter
  // shapes and names to the program shape.
  const int64_t param_count = parameter_numbers_.size();
  for (int64_t i = 0; i < param_count; i++) {
    program_shape.add_parameters();
    program_shape.add_parameter_names();
  }
  for (const HloInstructionProto& instr : instructions_) {
    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
    // to verify continuity, we just need to verify that every parameter is in
    // the right range.
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      const int64_t index = instr.parameter_number();
      TF_RET_CHECK(index >= 0 && index < param_count)
          << "invalid parameter number: " << index;
      *program_shape.mutable_parameters(index) = Shape(instr.shape());
      *program_shape.mutable_parameter_names(index) = instr.name();
    }
  }
  return program_shape;
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  TF_RET_CHECK(!instructions_.empty());
  return GetProgramShape(instructions_.back().id());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return GetProgramShape(root.handle());
}

void XlaBuilder::IsConstantVisitor(const int64_t op_handle, int depth,
                                   absl::flat_hash_set<int64>* visited,
                                   bool* is_constant) const {
  if (visited->contains(op_handle) || !*is_constant) {
    return;
  }

  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  HloInstructionProto to_print(instr);
  to_print.clear_shape();
  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
  const string indent =
      absl::StrJoin(std::vector<absl::string_view>(depth, "  "), "");
  if (VLOG_IS_ON(2)) {
    VLOG(2) << indent << "Visiting:";
    for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
      VLOG(2) << indent << l;
    }
  }
  switch (opcode) {
    default:
      for (const int64_t operand_id : instr.operand_ids()) {
        IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
      }
      // TODO(b/32495713): We aren't checking the called computations.
      break;

    case HloOpcode::kGetDimensionSize:
      // GetDimensionSize is always considered constant in XLA -- if a dynamic
      // dimension is present, -1 is returned.
      break;
    // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant.  We cannot set is_functional=false in other
      // similar cases since we're already relying on IsConstant to return
      // true.
    case HloOpcode::kCustomCall:
      if (instr.custom_call_target() == "SetBound") {
        // SetBound is considered constant -- the bound is used as the value.
        break;
      }
      TF_FALLTHROUGH_INTENDED;
    case HloOpcode::kWhile:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kScatter:
      // TODO(b/32495713): We aren't checking the embedded computation in
      // Scatter.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter:
      *is_constant = false;
      break;
    case HloOpcode::kGetTupleElement: {
      const HloInstructionProto& operand_instr =
          *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
      if (HloOpcodeString(HloOpcode::kTuple) == operand_instr.opcode()) {
        IsConstantVisitor(operand_instr.operand_ids(instr.tuple_index()),
                          depth + 1, visited, is_constant);
      } else {
        for (const int64_t operand_id : instr.operand_ids()) {
          IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
        }
      }
    }
  }
  if (VLOG_IS_ON(1) && !*is_constant) {
    VLOG(1) << indent << "Non-constant: ";
    for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
      VLOG(1) << indent << l;
    }
  }
  visited->insert(op_handle);
}

Status XlaBuilder::SetDynamicBinding(int64_t dynamic_size_param_num,
                                     ShapeIndex dynamic_size_param_index,
                                     int64_t target_param_num,
                                     ShapeIndex target_param_index,
                                     int64_t target_dim_num) {
  bool param_exists = false;
  for (size_t index = 0; index < instructions_.size(); ++index) {
    HloInstructionProto& instr = instructions_[index];
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
        instr.parameter_number() == target_param_num) {
      param_exists = true;
      Shape param_shape(instr.shape());
      Shape* param_shape_ptr = &param_shape;
      for (int64_t index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      *instr.mutable_shape() = param_shape.ToProto();
      instruction_shapes_[index] =
          absl::make_unique<Shape>(std::move(param_shape));
    }
  }
  if (!param_exists) {
    return InvalidArgument(
        "Asked to mark parameter %lld as a dynamically sized parameter, but "
        "the parameter doesn't exist",
        target_param_num);
  }

  TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
      DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
                                                dynamic_size_param_index},
      DynamicParameterBinding::DynamicDimension{
          target_param_num, target_param_index, target_dim_num}));
  return Status::OK();
}
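
// Example (illustrative): with parameter 0 of shape f32[10] and a scalar s32
// parameter 1 holding its real size, the call below marks dimension 0 of
// parameter 0 as dynamic, sized by parameter 1:
//   TF_RETURN_IF_ERROR(builder.SetDynamicBinding(
//       /*dynamic_size_param_num=*/1, /*dynamic_size_param_index=*/{},
//       /*target_param_num=*/0, /*target_param_index=*/{},
//       /*target_dim_num=*/0));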

Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
                                                   std::string attribute,
                                                   std::string value) {
  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
  (*frontend_attributes->mutable_map())[attribute] = std::move(value);
  return Status::OK();
}
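
// Example (illustrative; the attribute key is hypothetical): frontend
// attributes are opaque key/value strings carried through to the HLO proto:
//   TF_RETURN_IF_ERROR(builder.SetInstructionFrontendAttribute(
//       op, "_my_annotation", "some_value"));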

XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->ReportError(
        AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
    return {};
  }
  return build_status.ConsumeValueOrDie();
}

Status XlaBuilder::GetCurrentStatus() const {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }
  return Status::OK();
}

StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());
  return Build(instructions_.back().id(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
                                           bool remove_dynamic_dimensions) {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return Build(root.handle(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(int64_t root_id,
                                           bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());

  // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet,
  // so remove all dynamic dimensions before building the XLA program, until
  // the backend supports them.
  if (remove_dynamic_dimensions) {
    std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
      if (shape->tuple_shapes_size() != 0) {
        for (int64_t i = 0; i < shape->tuple_shapes_size(); ++i) {
          remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
        }
      }
      for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
        shape->set_dynamic_dimension(i, false);
      }
    };
    for (size_t index = 0; index < instructions_.size(); ++index) {
      remove_dynamic_dimension(instruction_shapes_[index].get());
      *instructions_[index].mutable_shape() =
          instruction_shapes_[index]->ToProto();
    }
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
  *entry.mutable_program_shape() = program_shape.ToProto();
  entry.set_root_id(root_id);

  for (auto& instruction : instructions_) {
    // Ensures that the instruction names are unique among the whole graph.
    instruction.set_name(
        GetFullName(instruction.name(), kNameSeparator, instruction.id()));
    entry.add_instructions()->Swap(&instruction);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = entry.program_shape();
  for (auto& e : embedded_) {
    module->add_computations()->Swap(&e.second);
  }
  module->add_computations()->Swap(&entry);
  if (!input_output_aliases_.empty()) {
    TF_RETURN_IF_ERROR(
        PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
  }
  *(module->mutable_dynamic_parameter_binding()) =
      dynamic_parameter_binding_.ToProto();

  // Clear data held by this builder.
  this->instructions_.clear();
  this->instruction_shapes_.clear();
  this->handle_to_index_.clear();
  this->embedded_.clear();
  this->parameter_numbers_.clear();

  return std::move(computation);
}
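
// Example (illustrative): a minimal end-to-end use of the builder; the last
// enqueued instruction becomes the root when no explicit root is given:
//   XlaBuilder b("add_one");
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "p");
//   Add(p, ConstantR0<float>(&b, 1.0f));
//   StatusOr<XlaComputation> computation = b.Build();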

/* static */ Status XlaBuilder::PopulateInputOutputAlias(
    HloModuleProto* module, const ProgramShape& program_shape,
    const std::vector<InputOutputAlias>& input_output_aliases) {
  HloInputOutputAliasConfig config(program_shape.result());
  for (auto& alias : input_output_aliases) {
    // The HloInputOutputAliasConfig does not do parameter validation, as it
    // only carries the result shape. Maybe it should be constructed with a
    // ProgramShape to allow full validation. We will still get an error when
    // trying to compile the HLO module, but it would be better to have
    // validation at this stage.
    if (alias.param_number >= program_shape.parameters_size()) {
      return InvalidArgument("Invalid parameter number %ld (total %ld)",
                             alias.param_number,
                             program_shape.parameters_size());
    }
    const Shape& parameter_shape = program_shape.parameters(alias.param_number);
    if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
      return InvalidArgument("Invalid parameter %ld index: %s",
                             alias.param_number,
                             alias.param_index.ToString().c_str());
    }
    TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
                                         alias.param_index, alias.kind));
  }
  *module->mutable_input_output_alias() = config.ToProto();
  return Status::OK();
}

StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }

  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                 XlaOp operand) {
  TF_RETURN_IF_ERROR(first_error_);

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

  CHECK(ShapeUtil::IsScalar(*operand_shape) ||
        operand_shape->rank() == output_shape.rank());
  Shape broadcast_shape =
      ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());

  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(*operand_shape)) {
    return InDimBroadcast(broadcast_shape, operand, {});
  }

  // Do explicit broadcast for degenerate broadcast.
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  for (int i = 0; i < operand_shape->rank(); i++) {
    if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_shape->dimensions(i));
    } else {
      TF_RET_CHECK(operand_shape->dimensions(i) == 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand shape: "
          << *operand_shape << "; output_shape: " << output_shape;
    }
  }

  Shape reshaped_shape =
      ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);

  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);

  for (auto& unmodified : unmodified_dims) {
    if (operand_shape->is_dynamic_dimension(unmodified.first)) {
      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
    }
  }

  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(
      XlaOp reshaped_operand,
      ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
}
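
// Example (illustrative): broadcasting an f32[4,1] operand to f32[4,3] first
// reshapes away the size-1 dimension (f32[4,1] -> f32[4]) and then emits an
// in-dimension broadcast with dimensions={0}, mapping operand dimension 0 to
// output dimension 0.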

XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
    return AddOpWithShape(unop, shape, {operand});
  });
}

XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions,
                           absl::optional<ComparisonDirection> direction,
                           absl::optional<Comparison::Type> type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBinaryOpShape(
                         binop, *lhs_shape, *rhs_shape, broadcast_dimensions));

    const int64_t lhs_rank = lhs_shape->rank();
    const int64_t rhs_rank = rhs_shape->rank();

    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
      const bool should_broadcast_lhs = lhs_rank < rhs_rank;
      XlaOp from = should_broadcast_lhs ? lhs : rhs;
      const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;

      std::vector<int64> to_size;
      std::vector<bool> to_size_is_dynamic;
      for (int i = 0; i < shape.rank(); i++) {
        to_size.push_back(shape.dimensions(i));
        to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
      }
      for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
        int64_t to_dim = broadcast_dimensions[from_dim];
        to_size[to_dim] = from_shape.dimensions(from_dim);
        to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
      }

      const Shape& broadcasted_shape = ShapeUtil::MakeShape(
          from_shape.element_type(), to_size, to_size_is_dynamic);
      TF_ASSIGN_OR_RETURN(
          XlaOp broadcasted_operand,
          InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));

      updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
      updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
    }

    TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
                        GetShapePtr(updated_lhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_lhs,
                          AddBroadcastSequence(shape, updated_lhs));
    }
    TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
                        GetShapePtr(updated_rhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_rhs,
                          AddBroadcastSequence(shape, updated_rhs));
    }

    if (binop == HloOpcode::kCompare) {
      if (!direction.has_value()) {
        return InvalidArgument(
            "kCompare expects a ComparisonDirection, but none provided.");
      }
      if (type == absl::nullopt) {
        return Compare(shape, updated_lhs, updated_rhs, *direction);
      } else {
        return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
      }
    }

    if (direction.has_value()) {
      return InvalidArgument(
          "A comparison direction is provided for a non-compare opcode: %s.",
          HloOpcodeString(binop));
    }
    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
  });
}
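
// Example (illustrative): Add(lhs, rhs, /*broadcast_dimensions=*/{0}) with
// lhs of shape f32[2,3] and rhs of shape f32[2] infers an output of f32[2,3],
// in-dimension-broadcasts rhs to f32[2,3] (its dimension 0 mapping to output
// dimension 0), and then emits kAdd on equally shaped operands.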

XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                      XlaOp lhs, XlaOp rhs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), binop, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction) {
  TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
  return Compare(
      shape, lhs, rhs, direction,
      Comparison::DefaultComparisonType(operand_shape.element_type()));
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction,
                                    Comparison::Type type) {
  HloInstructionProto instr;
  instr.set_comparison_direction(ComparisonDirectionToString(direction));
  instr.set_comparison_type(ComparisonTypeToString(type));
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
}

XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    XlaOp updated_ehs = ehs;
    // The client API supports implicit broadcast for kSelect and kClamp, but
    // XLA does not support implicit broadcast. Make implicit broadcast
    // explicit and update the operands.
    if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
      TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
      TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
      TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));

      absl::optional<Shape> non_scalar_shape;
      for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
        if (shape->IsArray() && shape->rank() != 0) {
          if (non_scalar_shape.has_value()) {
            // TODO(jpienaar): The case where we need to compute the broadcasted
            // shape by considering multiple of the shapes is not implemented.
            // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
            TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
                         shape->dimensions())
                << "Unimplemented implicit broadcast.";
          } else {
            non_scalar_shape = *shape;
          }
        }
      }
      if (non_scalar_shape.has_value()) {
        if (ShapeUtil::IsScalar(*lhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_lhs,
                              AddBroadcastSequence(*non_scalar_shape, lhs));
        }
        if (ShapeUtil::IsScalar(*rhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_rhs,
                              AddBroadcastSequence(*non_scalar_shape, rhs));
        }
        if (ShapeUtil::IsScalar(*ehs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_ehs,
                              AddBroadcastSequence(*non_scalar_shape, ehs));
        }
      }
    }

    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
    TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
    StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
        triop, *lhs_shape, *rhs_shape, *ehs_shape);
    if (!status_or_shape.status().ok()) {
      return InvalidArgument(
          "%s Input scalar shapes may have been changed to non-scalar shapes.",
          status_or_shape.status().error_message());
    }

    return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
                          {updated_lhs, updated_rhs, updated_ehs});
  });
}

XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (literal.shape().IsArray() && literal.element_count() > 1 &&
        literal.IsAllFirst()) {
      Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
      HloInstructionProto instr;
      *instr.mutable_shape() = scalar.shape().ToProto();
      *instr.mutable_literal() = scalar.ToProto();
      TF_ASSIGN_OR_RETURN(
          XlaOp scalar_op,
          AddInstruction(std::move(instr), HloOpcode::kConstant));
      return Broadcast(scalar_op, literal.shape().dimensions());
    } else {
      HloInstructionProto instr;
      *instr.mutable_shape() = literal.shape().ToProto();
      *instr.mutable_literal() = literal.ToProto();
      return AddInstruction(std::move(instr), HloOpcode::kConstant);
    }
  });
}

XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(iota_dimension);
    return AddInstruction(std::move(instr), HloOpcode::kIota);
  });
}

XlaOp XlaBuilder::Iota(PrimitiveType type, int64_t size) {
  return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}
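
// Example (illustrative): Iota(S32, 5) yields the rank-1 value
// [0, 1, 2, 3, 4]; the Shape overload generalizes this, filling the chosen
// iota_dimension of a possibly higher-rank result.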

XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       absl::Span<const XlaOp> operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
                                         operand_shape_ptrs,
                                         /*to_apply=*/called_program_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
  });
}

XlaOp XlaBuilder::Parameter(
    int64_t parameter_number, const Shape& shape, const string& name,
    const std::vector<bool>& replicated_at_leaf_buffers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!parameter_numbers_.insert(parameter_number).second) {
      return InvalidArgument("parameter %d already registered",
                             parameter_number);
    }
    instr.set_parameter_number(parameter_number);
    instr.set_name(name);
    *instr.mutable_shape() = shape.ToProto();
    if (!replicated_at_leaf_buffers.empty()) {
      auto replication = instr.mutable_parameter_replication();
      for (bool replicated : replicated_at_leaf_buffers) {
        replication->add_replicated_at_leaf_buffers(replicated);
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kParameter);
  });
}

XlaOp XlaBuilder::Broadcast(XlaOp operand,
                            absl::Span<const int64> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));

    // The client-level broadcast op just appends dimensions on the left (adds
    // lowest numbered dimensions). The HLO broadcast instruction is more
    // flexible and can add new dimensions anywhere. The instruction's
    // dimensions field maps operand dimensions to dimensions in the broadcast
    // output, so to append dimensions on the left the instruction's dimensions
    // should just be the n highest dimension numbers of the output shape where
    // n is the number of input dimensions.
    const int64_t operand_rank = operand_shape->rank();
    std::vector<int64> dimensions(operand_rank);
    for (int i = 0; i < operand_rank; ++i) {
      dimensions[i] = i + shape.rank() - operand_rank;
    }
    return InDimBroadcast(shape, operand, dimensions);
  });
}
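
// Example (illustrative): Broadcast of an f32[2,3] operand with
// broadcast_sizes={4,5} produces f32[4,5,2,3]; the instruction's dimensions
// field becomes {2, 3}, mapping the operand's dimensions to the two highest
// output dimensions.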

XlaOp XlaBuilder::BroadcastInDim(
    XlaOp operand, const absl::Span<const int64> out_dim_size,
    const absl::Span<const int64> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Compute the output shape; in the case of a degenerate broadcast,
    // out_dim_size is not necessarily the same as the dimension sizes of the
    // output shape.
    TF_ASSIGN_OR_RETURN(auto output_shape,
                        ShapeUtil::MakeValidatedShape(
                            operand_shape->element_type(), out_dim_size));
    int64_t broadcast_rank = broadcast_dimensions.size();
    if (operand_shape->rank() != broadcast_rank) {
      return InvalidArgument(
          "Size of broadcast_dimensions has to match operand's rank; operand "
          "rank: %lld, size of broadcast_dimensions %u.",
          operand_shape->rank(), broadcast_dimensions.size());
    }
    for (int i = 0; i < broadcast_rank; i++) {
      const int64_t num_dims = out_dim_size.size();
      if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > num_dims) {
        return InvalidArgument("Broadcast dimension %lld is out of bound",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(
          broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           *operand_shape, output_shape, broadcast_dimensions)
                           .status());
    std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_rank; i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If broadcast is not degenerate, return broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}

StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                            int64_t inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  if (inferred_dimension != -1) {
    instr.add_dimensions(inferred_dimension);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
}

XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
                        absl::Span<const int64> limit_indices,
                        absl::Span<const int64> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
                                         *operand_shape, start_indices,
                                         limit_indices, strides));
    return SliceInternal(shape, operand, start_indices, limit_indices, strides);
  });
}

StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
                                          absl::Span<const int64> start_indices,
                                          absl::Span<const int64> limit_indices,
                                          absl::Span<const int64> strides) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int i = 0, end = start_indices.size(); i < end; i++) {
    auto* slice_config = instr.add_slice_dimensions();
    slice_config->set_start(start_indices[i]);
    slice_config->set_limit(limit_indices[i]);
    slice_config->set_stride(strides[i]);
  }
  return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}

XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index,
                             int64_t limit_index, int64_t stride,
                             int64_t dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> starts(shape->rank(), 0);
    std::vector<int64> limits(shape->dimensions().begin(),
                              shape->dimensions().end());
    std::vector<int64> strides(shape->rank(), 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}
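
// Example (illustrative): SliceInDim(x, /*start_index=*/1, /*limit_index=*/5,
// /*stride=*/2, /*dimno=*/0) on an f32[6,4] operand keeps rows 1 and 3,
// producing f32[2,4]; every other dimension is passed through whole.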

XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            *operand_shape, start_indices_shapes, slice_sizes));
    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64> slice_sizes) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  for (int64_t size : slice_sizes) {
    instr.add_dynamic_slice_sizes(size);
  }

  std::vector<XlaOp> operands = {operand};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
}

XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                         *operand_shape, *update_shape, start_indices_shapes));
    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  std::vector<XlaOp> operands = {operand, update};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                        operands);
}

XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                              int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
                                         operand_shape_ptrs, dimension));
    return ConcatInDimInternal(shape, operands, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  instr.add_dimensions(dimension);

  return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}

XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
                      const PaddingConfig& padding_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
                        GetShapePtr(padding_value));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferPadShape(
                         *operand_shape, *padding_value_shape, padding_config));
    return PadInternal(shape, operand, padding_value, padding_config);
  });
}

XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                           int64_t pad_lo, int64_t pad_hi) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    PaddingConfig padding_config = MakeNoPaddingConfig(shape->rank());
    auto* dims = padding_config.mutable_dimensions(dimno);
    dims->set_edge_padding_low(pad_lo);
    dims->set_edge_padding_high(pad_hi);
    return Pad(operand, padding_value, padding_config);
  });
}
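
// Example (illustrative): PadInDim(x, zero, /*dimno=*/1, /*pad_lo=*/1,
// /*pad_hi=*/2) on an f32[2,3] operand pads only dimension 1, producing
// f32[2,6]; interior padding stays at zero.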
1115 
PadInternal(const Shape & shape,XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)1116 StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
1117                                         XlaOp padding_value,
1118                                         const PaddingConfig& padding_config) {
1119   HloInstructionProto instr;
1120   *instr.mutable_shape() = shape.ToProto();
1121   *instr.mutable_padding_config() = padding_config;
1122   return AddInstruction(std::move(instr), HloOpcode::kPad,
1123                         {operand, padding_value});
1124 }
1125 
Reshape(XlaOp operand,absl::Span<const int64> dimensions,absl::Span<const int64> new_sizes,int64_t inferred_dimension)1126 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
1127                           absl::Span<const int64> new_sizes,
1128                           int64_t inferred_dimension) {
1129   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1130     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1131     TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
1132                                                *operand_shape, dimensions,
1133                                                new_sizes, inferred_dimension));
1134     XlaOp transposed = IsIdentityPermutation(dimensions)
1135                            ? operand
1136                            : Transpose(operand, dimensions);
1137     return ReshapeInternal(shape, transposed, inferred_dimension);
1138   });
1139 }

XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> dimensions(shape->dimensions_size());
    std::iota(dimensions.begin(), dimensions.end(), 0);
    return Reshape(operand, dimensions, new_sizes, inferred_dimension);
  });
}

XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    return ReshapeInternal(shape, operand, inferred_dimension);
  });
}

XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
                                 absl::Span<const XlaOp> dim_sizes,
                                 absl::Span<const int64> new_size_bounds,
                                 const std::vector<bool>& dims_are_dynamic) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> dim_size_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
                        GetOperandShapes(dim_sizes));

    absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferDynamicReshapeShape(
                            *operand_shape, dim_size_shape_ptrs,
                            new_size_bounds, dims_are_dynamic));
    TF_RETURN_IF_ERROR(first_error_);
    std::vector<XlaOp> operands;
    operands.reserve(1 + dim_sizes.size());
    operands.push_back(operand);
    for (const XlaOp& dim_size : dim_sizes) {
      operands.push_back(dim_size);
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
                          operands);
  });
}

XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Not collapsing anything; return the operand rather than enqueueing a
      // trivial reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Check that the collapsed dimensions are in order and consecutive.
    for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
    TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));

    VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
    VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");

    std::vector<int64> new_sizes;
    for (int i = 0; i < original_shape->rank(); ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape->dimensions(i));
      } else {
        new_sizes.back() *= original_shape->dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";

    return Reshape(operand, new_sizes);
  });
}
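
// Worked example of the size computation above: collapsing dimensions {1, 2}
// of an f32[2, 3, 4, 5] operand multiplies the collapsed sizes together,
// giving new_sizes = [2, 12, 5]:
//
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "x");
//   Collapse(x, /*dimensions=*/{1, 2});  // f32[2, 12, 5]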

void XlaBuilder::Trace(const string& tag, XlaOp operand) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
  });
}

XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
    TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
    TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
    HloOpcode opcode =
        true_shape->IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
    return TernaryOp(opcode, pred, on_true, on_false);
  });
}

XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferVariadicOpShape(
                            HloOpcode::kTuple, operand_shape_ptrs));
    return TupleInternal(shape, elements);
  });
}

StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
                                          absl::Span<const XlaOp> elements) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
}

XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64_t index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
    if (!tuple_shape->IsTuple()) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(*tuple_shape));
    }
    if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
      return InvalidArgument(
          "GetTupleElement() index (%d) out of range for tuple shape %s", index,
          ShapeUtil::HumanString(*tuple_shape));
    }
    return GetTupleElementInternal(
        ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
        index);
  });
}
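
// Usage sketch (free-function wrappers from xla_builder.h assumed): Tuple and
// GetTupleElement round-trip an element through a tuple:
//
//   XlaOp s = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "s");
//   XlaOp v = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8}), "v");
//   XlaOp t = Tuple(&b, {s, v});       // (f32[], f32[8])
//   XlaOp v2 = GetTupleElement(t, 1);  // f32[8]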

StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
                                                    XlaOp tuple_data,
                                                    int64_t index) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_tuple_index(index);
  return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
                        {tuple_data});
}

XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
                      const PrecisionConfig* precision_config,
                      absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));

    DotDimensionNumbers dimension_numbers;
    dimension_numbers.add_lhs_contracting_dimensions(
        lhs_shape->dimensions_size() == 1 ? 0 : 1);
    dimension_numbers.add_rhs_contracting_dimensions(0);
    return DotGeneral(lhs, rhs, dimension_numbers, precision_config,
                      preferred_element_type);
  });
}
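
// Shape sketch for the default contracting dimensions chosen above: a rank-2
// lhs contracts its dimension 1 with rhs dimension 0, so
// f32[m, k] dot f32[k, n] -> f32[m, n]; a rank-1 lhs contracts its only
// dimension, so f32[k] dot f32[k, n] -> f32[n].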

XlaOp XlaBuilder::DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferDotOpShape(
            *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
    return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
                              precision_config);
  });
}

StatusOr<XlaOp> XlaBuilder::DotGeneralInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_dot_dimension_numbers() = dimension_numbers;
  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
}

Status XlaBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) const {
  if (lhs_shape.rank() != rhs_shape.rank()) {
    return InvalidArgument(
        "Convolution arguments must have same number of "
        "dimensions. Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_dims = lhs_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          return InvalidArgument("Expected %d elements for %s, but got %d.",
                                 num_spatial_dims, field_name, numbers.size());
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
                                   field_name, i, numbers.Get(i));
          }
        }
        return Status::OK();
      };
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("input_spatial_dimensions",
                               dimension_numbers.input_spatial_dimensions()));
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("kernel_spatial_dimensions",
                               dimension_numbers.kernel_spatial_dimensions()));
  return check_spatial_dimensions(
      "output_spatial_dimensions",
      dimension_numbers.output_spatial_dimensions());
}

XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> window_strides, Padding padding,
                       int64_t feature_group_count, int64_t batch_group_count,
                       const PrecisionConfig* precision_config,
                       absl::optional<PrimitiveType> preferred_element_type) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()),
      feature_group_count, batch_group_count, precision_config,
      preferred_element_type);
}
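
// Note on the defaults used above: CreateDefaultConvDimensionNumbers treats
// input dimension 0 as batch and dimension 1 as feature (NCHW for 2D), and
// kernel dimension 0 as output feature and dimension 1 as input feature
// (OIHW for 2D), with the remaining dimensions spatial.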

XlaOp XlaBuilder::ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()),
                     feature_group_count, batch_group_count, precision_config,
                     preferred_element_type);
}

XlaOp XlaBuilder::ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));

    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> base_area_dimensions(
        dimension_numbers.input_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
         ++i) {
      base_area_dimensions[i] =
          lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
    }

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    return ConvGeneral(lhs, rhs, window_strides,
                       MakePadding(base_area_dimensions, window_dimensions,
                                   window_strides, padding),
                       dimension_numbers, feature_group_count,
                       batch_group_count, precision_config,
                       preferred_element_type);
  });
}

XlaOp XlaBuilder::ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers, feature_group_count,
                            batch_group_count, precision_config,
                            preferred_element_type);
}

XlaOp XlaBuilder::ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    TF_ASSIGN_OR_RETURN(Window window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding,
                            lhs_dilation, rhs_dilation));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferConvolveShape(
            *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
            window, dimension_numbers, preferred_element_type));
    return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
                                      padding, lhs_dilation, rhs_dilation,
                                      dimension_numbers, feature_group_count,
                                      batch_group_count, precision_config);
  });
}
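
// A note on the dilation arguments fed into the window inference above:
// rhs_dilation d expands the effective kernel extent to (k - 1) * d + 1
// (atrous convolution), while lhs_dilation dilates the input, which is how
// gradients of strided convolutions are expressed. Empty spans mean no
// dilation, as in ConvGeneral.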

StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
  TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
                                         window_dimensions, window_strides,
                                         padding, lhs_dilation, rhs_dilation));
  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferConvolveShape(
          *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);
  instr.set_padding_type(padding_type);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return std::move(instr);
}

XlaOp XlaBuilder::DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionInputGrad");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {input_sizes, lhs, rhs});
  });
}

XlaOp XlaBuilder::DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(activations, gradients, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers,
                               feature_group_count, batch_group_count,
                               precision_config, padding_type,
                               preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionKernelGrad");
    // The kernel gradient has the shape of the kernel itself and shouldn't
    // have any dynamic sizes.
    instr.mutable_shape()->clear_is_dynamic_dimension();
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {activations, gradients});
  });
}

XlaOp XlaBuilder::DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));
    instr.set_custom_call_target("DynamicConvolutionForward");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }

  return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
}

XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                      const absl::Span<const int64> fft_length) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
                                         *operand_shape, fft_type, fft_length));
    return FftInternal(shape, operand, fft_type, fft_length);
  });
}
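
// Usage sketch (free-function wrapper from xla_builder.h assumed): fft_length
// gives the sizes of the innermost dimensions to transform, e.g. a forward
// FFT over the last dimension of a c64[8, 32] operand:
//
//   Fft(x, FftType::FFT, /*fft_length=*/{32});  // c64[8, 32]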

StatusOr<XlaOp> XlaBuilder::FftInternal(
    const Shape& shape, XlaOp operand, const FftType fft_type,
    const absl::Span<const int64> fft_length) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_fft_type(fft_type);
  for (int64_t i : fft_length) {
    instr.add_fft_length(i);
  }

  return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
}

StatusOr<XlaOp> XlaBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  HloInstructionProto instr;
  *instr.mutable_triangular_solve_options() = std::move(options);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b});
}

StatusOr<XlaOp> XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
                                             bool lower) {
  HloInstructionProto instr;
  xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
  options.set_lower(lower);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
}

XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
    instr.set_infeed_config(config);

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    // Infeed takes a single token operand. Generate the token to pass to the
    // infeed.
    XlaOp token;
    auto make_token = [&]() {
      HloInstructionProto token_instr;
      *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
      return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
    };
    if (sharding()) {
      // Arbitrarily assign the token to device 0.
      OpSharding sharding = sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this, sharding);
      TF_ASSIGN_OR_RETURN(token, make_token());
    } else {
      TF_ASSIGN_OR_RETURN(token, make_token());
    }

    // The sharding is set by the client according to the data tuple shape.
    // However, the shape of the infeed instruction is a tuple containing the
    // data and a token. For tuple sharding type, the sharding must be changed
    // to accommodate the token.
    XlaOp infeed;
    if (sharding() && sharding()->type() == OpSharding::TUPLE) {
      // TODO(b/80000000): Remove this when clients have been updated to handle
      // tokens.
      OpSharding infeed_instruction_sharding = *sharding();
      // Arbitrarily assign the token to device 0.
      *infeed_instruction_sharding.add_tuple_shardings() =
          sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this,
                                                  infeed_instruction_sharding);
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    } else {
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    }

    // The infeed instruction produces a tuple of the infeed data and a token.
    // Return the XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto infeed_data;
    *infeed_data.mutable_shape() = shape.ToProto();
    infeed_data.set_tuple_index(0);
    return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
                          {infeed});
  });
}

XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
                                  const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }
    return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
  });
}

StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = infeed_instruction_shape.ToProto();
  instr.set_infeed_config(config);
  return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
}

void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
                         const string& outfeed_config) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();

    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();

    instr.set_outfeed_config(outfeed_config);

    // Outfeed takes a token as its second operand. Generate the token to pass
    // to the outfeed.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    TF_RETURN_IF_ERROR(
        AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
            .status());

    // The outfeed instruction produces a token. However, existing users expect
    // a nil shape (empty tuple). This should only be relevant if the outfeed is
    // the root of a computation.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto tuple_instr;
    *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();

    // The dummy tuple should have no sharding.
    {
      XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
      TF_ASSIGN_OR_RETURN(
          XlaOp empty_tuple,
          AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
      return empty_tuple;
    }
  });
}

XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
                                   const Shape& shape_with_layout,
                                   const string& outfeed_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    return OutfeedWithTokenInternal(operand, token, shape_with_layout,
                                    outfeed_config);
  });
}

StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
    const string& outfeed_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
  *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
  instr.set_outfeed_config(outfeed_config);
  return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
                        {operand, token});
}

XlaOp XlaBuilder::CreateToken() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
  });
}

XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (tokens.empty()) {
      return InvalidArgument("AfterAll requires at least one operand");
    }
    for (int i = 0, end = tokens.size(); i < end; ++i) {
      XlaOp operand = tokens[i];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      if (!operand_shape->IsToken()) {
        return InvalidArgument(
            "All operands to AfterAll must be tokens; operand %d has shape %s",
            i, ShapeUtil::HumanString(*operand_shape));
      }
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
  });
}
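
// Token usage sketch (free-function wrappers from xla_builder.h assumed;
// `b` and `shape` are illustrative): tokens sequence side-effecting ops, and
// AfterAll joins several orderings into one dependency:
//
//   XlaOp t0 = CreateToken(&b);
//   XlaOp in = InfeedWithToken(t0, shape);  // (data, token) tuple
//   XlaOp t1 = GetTupleElement(in, 1);
//   XlaOp joined = AfterAll(&b, {t0, t1});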

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, absl::optional<Window> window,
    absl::optional<ConvolutionDimensionNumbers> dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      int64_t operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        ++operand_num;
      }
    }
    return CustomCallInternal(call_target_name, operands, shape, opaque,
                              operand_shapes_with_layout, has_side_effect,
                              output_operand_aliasing, literal, window, dnums,
                              schedule, api_version);
  });
}
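
// Usage sketch, assuming a target named "MyCustomOp" has been registered with
// the backend (registration is backend-specific and not shown; the free
// function wrapper from xla_builder.h is assumed):
//
//   CustomCall(&b, "MyCustomOp", /*operands=*/{x, y},
//              /*shape=*/ShapeUtil::MakeShape(F32, {128}));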

StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, absl::optional<Window> window,
    absl::optional<ConvolutionDimensionNumbers> dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_custom_call_target(call_target_name);
  instr.set_backend_config(opaque);
  if (operand_shapes_with_layout.has_value()) {
    instr.set_constrain_layout(true);
    for (const Shape& operand_shape : *operand_shapes_with_layout) {
      *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
    }
  }
  if (literal != nullptr) {
    *instr.mutable_literal() = literal->ToProto();
  }
  instr.set_custom_call_has_side_effect(has_side_effect);
  for (const auto& pair : output_operand_aliasing) {
    auto aliasing = instr.add_custom_call_output_operand_aliasing();
    aliasing->set_operand_index(pair.second.first);
    for (int64_t index : pair.second.second) {
      aliasing->add_operand_shape_index(index);
    }
    for (int64_t index : pair.first) {
      aliasing->add_output_shape_index(index);
    }
  }
  if (window.has_value()) {
    *instr.mutable_window() = *window;
  }
  if (dnums.has_value()) {
    *instr.mutable_convolution_dimension_numbers() = *dnums;
  }
  instr.set_custom_call_schedule(schedule);
  instr.set_custom_call_api_version(api_version);
  return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const XlaComputation& computation, const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.set_custom_call_target(call_target_name);
    instr.set_backend_config(opaque);
    if (literal != nullptr) {
      *instr.mutable_literal() = literal->ToProto();
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      instr.set_constrain_layout(true);
      int64_t operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
        ++operand_num;
      }
    }
    AddCalledComputation(computation, &instr);
    for (const auto& pair : output_operand_aliasing) {
      auto aliasing = instr.add_custom_call_output_operand_aliasing();
      aliasing->set_operand_index(pair.second.first);
      for (int64_t index : pair.second.second) {
        aliasing->add_operand_shape_index(index);
      }
      for (int64_t index : pair.first) {
        aliasing->add_output_shape_index(index);
      }
    }
    instr.set_custom_call_schedule(schedule);
    instr.set_custom_call_api_version(api_version);
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
  });
}

XlaOp XlaBuilder::Transpose(XlaOp operand,
                            absl::Span<const int64> permutation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
                                         *operand_shape, permutation));
    return TransposeInternal(shape, operand, permutation);
  });
}
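
// Permutation sketch: permutation[i] names the operand dimension that becomes
// output dimension i, so {1, 0} is a matrix transpose:
//
//   Transpose(x, /*permutation=*/{1, 0});  // f32[3, 2] from f32[2, 3]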

StatusOr<XlaOp> XlaBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : permutation) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
}

XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
                                         *operand_shape, dimensions));
    return RevInternal(shape, operand, dimensions);
  });
}

StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> dimensions) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : dimensions) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
}

XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       const XlaComputation& comparator, int64_t dimension,
                       bool is_stable) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                        GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
                                         HloOpcode::kSort, operand_shape_ptrs));
    return SortInternal(shape, operands, comparator, dimension, is_stable);
  });
}

StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
                                         absl::Span<const XlaOp> operands,
                                         const XlaComputation& comparator,
                                         int64_t dimension, bool is_stable) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_is_stable(is_stable);
  if (dimension == -1) {
    TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
    dimension = keys_shape->rank() - 1;
  }
  instr.add_dimensions(dimension);
  AddCalledComputation(comparator, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
}
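
// Comparator sketch, assuming the helper from
// tensorflow/compiler/xla/client/lib/comparators.h: a stable ascending sort
// over the last dimension (dimension == -1 above), with values reordered
// alongside keys:
//
//   XlaComputation lt = CreateScalarLtComputation({F32, S32}, &b);
//   Sort({keys, values}, lt, /*dimension=*/-1, /*is_stable=*/true);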

XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
  });
}

XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return BitcastConvertTypeInternal(shape, operand);
  });
}

StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
                                                       XlaOp operand) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
                        {operand});
}

XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
  return TernaryOp(HloOpcode::kClamp, min, operand, max);
}

XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions,
                      absl::Span<const XlaOp> static_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!static_operands.empty()) {
      return Unimplemented("static_operands is not supported in Map");
    }

    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferMapShape(
                         operand_shape_ptrs, called_program_shape, dimensions));
    *instr.mutable_shape() = shape.ToProto();

    Shape output_shape(instr.shape());
    const int64_t output_rank = output_shape.rank();
    AddCalledComputation(computation, &instr);
    std::vector<XlaOp> new_operands(operands.begin(), operands.end());
    for (XlaOp& new_operand : new_operands) {
      TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
      const int64_t rank = shape->rank();
      if (rank != output_rank) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            InDimBroadcast(output_shape, new_operand, {}));
        TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
      }
      if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            AddBroadcastSequence(output_shape, new_operand));
      }
    }

    return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
  });
}
2192 
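// Example (illustrative sketch only; names are hypothetical). Map applies a
// scalar computation elementwise across the mapped dimensions:
//
//   XlaBuilder b("map_example");
//   auto sub = b.CreateSubBuilder("square");
//   XlaOp v = Parameter(sub.get(), 0, ShapeUtil::MakeShape(F32, {}), "v");
//   Mul(v, v);
//   XlaComputation square = sub->Build().ValueOrDie();
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 4}), "x");
//   Map(&b, {x}, square, /*dimensions=*/{0, 1});
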
XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
                        absl::Span<const XlaOp> parameters,
                        const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check the number of parameters per RNG distribution.
    switch (distribution) {
      case RandomDistribution::RNG_NORMAL:
      case RandomDistribution::RNG_UNIFORM:
        if (parameters.size() != 2) {
          return InvalidArgument(
              "RNG distribution (%s) expects 2 parameters, but got %ld",
              RandomDistribution_Name(distribution), parameters.size());
        }
        break;
      default:
        LOG(FATAL) << "unhandled distribution " << distribution;
    }

    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    return RngOpInternal(distribution, parameters, shape);
  });
}

StatusOr<XlaOp> XlaBuilder::RngOpInternal(RandomDistribution distribution,
                                          absl::Span<const XlaOp> parameters,
                                          const Shape& shape) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_distribution(distribution);

  return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
}

XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}

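// Example (illustrative sketch only; names are hypothetical). The
// distribution parameters are scalar operands of the target element type:
//
//   XlaBuilder b("rng_example");
//   RngNormal(ConstantR0<float>(&b, 0.0f), ConstantR0<float>(&b, 1.0f),
//             ShapeUtil::MakeShape(F32, {1024}));  // ~N(0, 1)
//   RngUniform(ConstantR0<float>(&b, 0.0f), ConstantR0<float>(&b, 1.0f),
//              ShapeUtil::MakeShape(F32, {1024}));  // ~U[0, 1)
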
XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
                                  XlaOp initial_state, const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
    Shape output_shape = shape;
    switch (output_shape.element_type()) {
      case PrimitiveType::F32:
      case PrimitiveType::S32:
      case PrimitiveType::U32:
        output_shape.set_element_type(PrimitiveType::U32);
        break;
      case PrimitiveType::F64:
      case PrimitiveType::S64:
      case PrimitiveType::U64:
        output_shape.set_element_type(PrimitiveType::U64);
        break;
      default:
        return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
                               PrimitiveType_Name(output_shape.element_type()));
    }
    return RngBitGeneratorInternal(
        ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm,
        initial_state);
  });
}

StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  HloInstructionProto instr;
  *instr.mutable_shape() = full_result_shape.ToProto();
  instr.set_rng_algorithm(algorithm);
  return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
                        {initial_state});
}

XlaOp XlaBuilder::While(const XlaComputation& condition,
                        const XlaComputation& body, XlaOp init) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Infer shape.
    TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
                        condition.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
                                         condition_program_shape,
                                         body_program_shape, *init_shape));
    return WhileInternal(shape, condition, body, init);
  });
}

StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
                                          const XlaComputation& condition,
                                          const XlaComputation& body,
                                          XlaOp init) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  // Body comes before condition computation in the vector.
  AddCalledComputation(body, &instr);
  AddCalledComputation(condition, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
}

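// Example (illustrative sketch only; names are hypothetical). Condition and
// body both take the loop-carried value, here a scalar counter:
//
//   XlaBuilder b("while_example");
//   const Shape s32 = ShapeUtil::MakeShape(S32, {});
//
//   auto cond_b = b.CreateSubBuilder("cond");
//   Lt(Parameter(cond_b.get(), 0, s32, "i"),
//      ConstantR0<int32>(cond_b.get(), 10));
//   XlaComputation cond = cond_b->Build().ValueOrDie();
//
//   auto body_b = b.CreateSubBuilder("body");
//   Add(Parameter(body_b.get(), 0, s32, "i"),
//       ConstantR0<int32>(body_b.get(), 1));
//   XlaComputation body = body_b->Build().ValueOrDie();
//
//   While(cond, body, ConstantR0<int32>(&b, 0));  // iterates i = 0..9
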
XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
                         const GatherDimensionNumbers& dimension_numbers,
                         absl::Span<const int64> slice_sizes,
                         bool indices_are_sorted) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
                        GetShapePtr(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
                                         *input_shape, *start_indices_shape,
                                         dimension_numbers, slice_sizes));
    return GatherInternal(shape, input, start_indices, dimension_numbers,
                          slice_sizes, indices_are_sorted);
  });
}

StatusOr<XlaOp> XlaBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  HloInstructionProto instr;
  instr.set_indices_are_sorted(indices_are_sorted);
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_gather_dimension_numbers() = dimension_numbers;
  for (int64_t bound : slice_sizes) {
    instr.add_gather_slice_sizes(bound);
  }

  return AddInstruction(std::move(instr), HloOpcode::kGather,
                        {input, start_indices});
}

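// Example (illustrative sketch only; names are hypothetical). Gathering whole
// rows of a [5, 3] operand by index: with start_indices of shape [n, 1] and
// index_vector_dim 1, the result has shape [n, 3].
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 spans the row slice
//   dnums.add_collapsed_slice_dims(0);  // drop the size-1 sliced row dim
//   dnums.add_start_index_map(0);       // indices address operand dim 0
//   dnums.set_index_vector_dim(1);
//   Gather(input, start_indices, dnums, /*slice_sizes=*/{1, 3});
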
XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers,
                          bool indices_are_sorted, bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
                        GetShapePtr(scatter_indices));
    TF_ASSIGN_OR_RETURN(const Shape* updates_shape, GetShapePtr(updates));
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        update_computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferScatterShape(
                         *input_shape, *scatter_indices_shape, *updates_shape,
                         to_apply_shape, dimension_numbers));
    return ScatterInternal(shape, input, scatter_indices, updates,
                           update_computation, dimension_numbers,
                           indices_are_sorted, unique_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::ScatterInternal(
    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
    const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_indices_are_sorted(indices_are_sorted);
    instr.set_unique_indices(unique_indices);
    *instr.mutable_shape() = shape.ToProto();
    *instr.mutable_scatter_dimension_numbers() = dimension_numbers;

    AddCalledComputation(update_computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kScatter,
                          {input, scatter_indices, updates});
  });
}

XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
                              const XlaComputation& true_computation,
                              XlaOp false_operand,
                              const XlaComputation& false_computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(predicate));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
      return InvalidArgument(
          "Argument to predicated-Conditional is not a scalar of PRED type "
          "(%s).",
          ShapeUtil::HumanString(*shape));
    }
    // The index of true_computation must be 0 and that of false_computation
    // must be 1.
    return ConditionalImpl(predicate, {&true_computation, &false_computation},
                           {true_operand, false_operand});
  });
}

XlaOp XlaBuilder::Conditional(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
      return InvalidArgument(
          "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    return ConditionalImpl(branch_index, branch_computations, branch_operands);
  });
}

XlaOp XlaBuilder::ConditionalImpl(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
                        GetShapePtr(branch_index));
    std::vector<Shape> branch_operand_shapes(branch_operands.size());
    std::vector<ProgramShape> branch_computation_shapes(
        branch_computations.size());
    for (int j = 0, end = branch_operands.size(); j < end; ++j) {
      TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
                          GetShape(branch_operands[j]));
      TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
                          branch_computations[j]->GetProgramShape());
    }
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferConditionalShape(
                            *branch_index_shape, branch_computation_shapes,
                            branch_operand_shapes));
    *instr.mutable_shape() = shape.ToProto();

    for (const XlaComputation* branch_computation : branch_computations) {
      AddCalledComputation(*branch_computation, &instr);
    }

    std::vector<XlaOp> operands(1, branch_index);
    for (const XlaOp branch_operand : branch_operands) {
      operands.emplace_back(branch_operand);
    }
    return AddInstruction(std::move(instr), HloOpcode::kConditional,
                          absl::MakeSpan(operands));
  });
}

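// Example (illustrative sketch only; names are hypothetical, and the two
// branch computations are assumed to be built elsewhere with
// CreateSubBuilder, each taking one F32 scalar parameter):
//
//   XlaBuilder b("cond_example");
//   XlaOp pred = ConstantR0<bool>(&b, true);
//   XlaOp x = ConstantR0<float>(&b, 2.0f);
//   Conditional(pred, /*true_operand=*/x, negate_computation,
//               /*false_operand=*/x, identity_computation);
//
// Both branches must produce the same result shape.
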
Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
  if (this != op.builder()) {
    return InvalidArgument(
        "XlaOp with handle %d was built by builder '%s', but is being used "
        "in builder '%s'",
        op.handle(), op.builder()->name(), name());
  }
  return Status::OK();
}

XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
                         const XlaComputation& computation,
                         absl::Span<const int64> dimensions_to_reduce) {
  return Reduce(absl::Span<const XlaOp>({operand}),
                absl::Span<const XlaOp>({init_value}), computation,
                dimensions_to_reduce);
}

XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
                         absl::Span<const XlaOp> init_values,
                         const XlaComputation& computation,
                         absl::Span<const int64> dimensions_to_reduce) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());

    std::vector<XlaOp> all_operands;
    all_operands.insert(all_operands.end(), operands.begin(), operands.end());
    all_operands.insert(all_operands.end(), init_values.begin(),
                        init_values.end());

    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
                        GetOperandShapes(all_operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });

    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferReduceShape(
            operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
    return ReduceInternal(shape, all_operands, computation,
                          dimensions_to_reduce);
  });
}

StatusOr<XlaOp> XlaBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64> dimensions_to_reduce) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();

    for (int64_t dim : dimensions_to_reduce) {
      instr.add_dimensions(dim);
    }

    AddCalledComputation(computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
  });
}

XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<int64> all_dimnos(operand_shape->rank());
    std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
    return Reduce(operand, init_value, computation, all_dimnos);
  });
}

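// Example (illustrative sketch only; names are hypothetical, and
// CreateScalarAddComputation comes from client/lib/arithmetic.h). Summing a
// [2, 3] operand over dimension 1 yields a [2] result:
//
//   XlaBuilder b("reduce_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
//   XlaComputation add = CreateScalarAddComputation(F32, &b);
//   Reduce(x, ConstantR0<float>(&b, 0.0f), add,
//          /*dimensions_to_reduce=*/{1});
//
// ReduceAll reduces over every dimension and produces a scalar.
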
XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReduceWindow(absl::MakeSpan(&operand, 1),
                      absl::MakeSpan(&init_value, 1), computation,
                      window_dimensions, window_strides, padding);
}

XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
                               absl::Span<const XlaOp> init_values,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const Shape* operand_shape = nullptr;
    for (const auto& operand : operands) {
      TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
      TF_RETURN_IF_ERROR(
          ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
                                window_dimensions, window_strides));
    }
    CHECK(operand_shape != nullptr);
    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);
    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions,
        // so we create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          ReduceWindowInternal(operands, init_values, computation,
                               window_dimensions, window_strides, {}, {},
                               padding_values));
      instr.set_custom_call_target("DynamicReduceWindowSamePadding");
      std::vector<XlaOp> args;
      args.insert(args.end(), operands.begin(), operands.end());
      args.insert(args.end(), init_values.begin(), init_values.end());
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall, args);
    }
    return ReduceWindowWithGeneralPadding(
        operands, init_values, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}

XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (operands.size() == 1) {
      const auto& operand = operands[0];
      const auto& init_value = init_values[0];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      operand_shapes.push_back(operand_shape);
      TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
      init_shapes.push_back(init_shape);

      TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                          computation.GetProgramShape());
      TF_ASSIGN_OR_RETURN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              window_dimensions, window_strides, padding,
                              /*lhs_dilation=*/base_dilations,
                              /*rhs_dilation=*/window_dilations));
      TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
                                           absl::MakeSpan(operand_shapes),
                                           absl::MakeSpan(init_shapes), window,
                                           to_apply_shape));
      return ReduceWindowInternal(shape, operands[0], init_values[0],
                                  computation, window);
    }

    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        ReduceWindowInternal(operands, init_values, computation,
                             window_dimensions, window_strides, base_dilations,
                             window_dilations, padding));
    std::vector<XlaOp> args;
    args.insert(args.end(), operands.begin(), operands.end());
    args.insert(args.end(), init_values.begin(), init_values.end());
    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, args);
  });
}

StatusOr<HloInstructionProto> XlaBuilder::ReduceWindowInternal(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  for (int i = 0; i < operands.size(); ++i) {
    const auto& operand = operands[i];
    const auto& init_value = init_values[i];
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    operand_shapes.push_back(operand_shape);
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
    init_shapes.push_back(init_shape);
  }
  TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                      computation.GetProgramShape());
  TF_ASSIGN_OR_RETURN(auto window,
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/base_dilations,
                          /*rhs_dilation=*/window_dilations));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferReduceWindowShape(
                          absl::MakeSpan(operand_shapes),
                          absl::MakeSpan(init_shapes), window, to_apply_shape));
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);
  AddCalledComputation(computation, &instr);
  return instr;
}

StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);

  AddCalledComputation(computation, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
                        {operand, init_value});
}

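// Example (illustrative sketch only; names are hypothetical, with
// CreateScalarMaxComputation from client/lib/arithmetic.h and MinValue from
// client/lib/constants.h). 2x2 max pooling with stride 2 over NHWC data:
//
//   XlaBuilder b("pool_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 8, 8, 3}), "x");
//   XlaComputation max = CreateScalarMaxComputation(F32, &b);
//   ReduceWindow(x, MinValue(&b, F32), max,
//                /*window_dimensions=*/{1, 2, 2, 1},
//                /*window_strides=*/{1, 2, 2, 1}, Padding::kValid);
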
XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                                    float epsilon, int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferBatchNormTrainingShape(
            *operand_shape, *scale_shape, *offset_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
                          {operand, scale, offset});
  });
}

XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                                     XlaOp mean, XlaOp variance, float epsilon,
                                     int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
    TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferBatchNormInferenceShape(
                            *operand_shape, *scale_shape, *offset_shape,
                            *mean_shape, *variance_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
                          {operand, scale, offset, mean, variance});
  });
}

XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                                XlaOp batch_var, XlaOp grad_output,
                                float epsilon, int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape, GetShapePtr(batch_mean));
    TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
    TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
                        GetShapePtr(grad_output));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBatchNormGradShape(
                         *operand_shape, *scale_shape, *batch_mean_shape,
                         *batch_var_shape, *grad_output_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
                          {operand, scale, batch_mean, batch_var, grad_output});
  });
}

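// Example (illustrative sketch only; names are hypothetical). For NHWC
// activations, feature_index 3 normalizes per channel; the result is a tuple
// of (normalized, batch_mean, batch_var):
//
//   XlaBuilder b("bn_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4, 4, 16}), "x");
//   XlaOp scale = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16}), "scale");
//   XlaOp offset = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {16}), "off");
//   BatchNormTraining(x, scale, offset, /*epsilon=*/0.001f,
//                     /*feature_index=*/3);
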
XlaOp XlaBuilder::AllGather(XlaOp operand, int64_t all_gather_dimension,
                            int64_t shard_count,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Layout>& layout,
                            const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    TF_ASSIGN_OR_RETURN(
        Shape inferred_shape,
        ShapeInference::InferAllGatherShape({operand_shape},
                                            all_gather_dimension, shard_count));
    if (layout) {
      *inferred_shape.mutable_layout() = *layout;
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = inferred_shape.ToProto();

    instr.add_dimensions(all_gather_dimension);
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      instr.set_use_global_device_ids(use_global_device_ids.value());
    }

    TF_ASSIGN_OR_RETURN(
        auto all_gather,
        AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand}));
    return all_gather;
  });
}

XlaOp XlaBuilder::CrossReplicaSum(
    XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    const Shape* element_shape;
    if (shape->IsTuple()) {
      if (shape->tuple_shapes_size() == 0) {
        return Unimplemented(
            "0 element tuple CrossReplicaSum is not supported");
      }
      element_shape = &shape->tuple_shapes(0);
    } else {
      element_shape = shape;
    }
    const Shape scalar_shape =
        ShapeUtil::MakeShape(element_shape->element_type(), {});
    auto b = CreateSubBuilder("sum");
    auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
    auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
    if (scalar_shape.element_type() == PRED) {
      Or(x, y);
    } else {
      Add(x, y);
    }
    TF_ASSIGN_OR_RETURN(auto computation, b->Build());
    return AllReduce(operand, computation, replica_groups,
                     /*channel_id=*/absl::nullopt);
  });
}

XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Shape>& shape_with_layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple AllReduce is not supported");
      }
      for (int64_t i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of AllReduce must have the same "
              "element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferAllReduceShape(operand_shapes));
    if (shape_with_layout) {
      if (!LayoutUtil::HasLayout(*shape_with_layout)) {
        return InvalidArgument("shape_with_layout must have the layout set: %s",
                               shape_with_layout->ToString());
      }
      if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
        return InvalidArgument(
            "Provided shape_with_layout must be compatible with the "
            "operand shape: %s vs %s",
            shape_with_layout->ToString(), operand_shape->ToString());
      }
      instr.set_constrain_layout(true);
      if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
        // For a single-element tuple, take the tuple element shape.
        TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
        *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
      } else {
        *instr.mutable_shape() = shape_with_layout->ToProto();
      }
    } else {
      *instr.mutable_shape() = inferred_shape.ToProto();
    }

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }

    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }

    AddCalledComputation(computation, &instr);

    TF_ASSIGN_OR_RETURN(
        auto all_reduce,
        AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
    if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
      // For a single-element tuple, wrap the result into a tuple.
      TF_RET_CHECK(operand_shapes.size() == 1);
      TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
      return Tuple({all_reduce});
    }
    return all_reduce;
  });
}

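// Example (illustrative sketch only; names are hypothetical, and
// CreateScalarAddComputation comes from client/lib/arithmetic.h). Summing an
// operand across all replicas:
//
//   XlaBuilder b("all_reduce_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1024}), "x");
//   AllReduce(x, CreateScalarAddComputation(F32, &b));
//
// Passing empty replica_groups means a single group of all replicas.
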
XlaOp XlaBuilder::ReduceScatter(
    XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
    int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups,
    const absl::optional<ChannelHandle>& channel_id,
    const absl::optional<Layout>& layout,
    const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple ReduceScatter is not supported");
      }
      for (int64_t i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of ReduceScatter must have "
              "the same element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferReduceScatterShape(
                            operand_shapes, scatter_dimension, shard_count));
    if (layout) {
      *inferred_shape.mutable_layout() = *layout;
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = inferred_shape.ToProto();

    AddCalledComputation(computation, &instr);

    instr.add_dimensions(scatter_dimension);
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      instr.set_use_global_device_ids(use_global_device_ids.value());
    }

    TF_ASSIGN_OR_RETURN(
        auto reduce_scatter,
        AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
    return reduce_scatter;
  });
}

XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
                           int64_t concat_dimension, int64_t split_count,
                           absl::Span<const ReplicaGroup> replica_groups,
                           const absl::optional<Layout>& layout) {
  // An array AllToAll may need to violate the layout constraint to be legal,
  // so use the tuple version when a layout is given.
  if (layout.has_value()) {
    return AllToAllTuple(operand, split_dimension, concat_dimension,
                         split_count, replica_groups, layout);
  }
  return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
                       replica_groups);
}

XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
                                int64_t concat_dimension, int64_t split_count,
                                absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape all_to_all_shape,
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count));
    HloInstructionProto instr;
    *instr.mutable_shape() = operand_shape->ToProto();
    if (replica_groups.empty()) {
      auto* group = instr.add_replica_groups();
      for (int64_t i = 0; i < split_count; ++i) {
        group->add_replica_ids(i);
      }
    } else {
      for (const ReplicaGroup& group : replica_groups) {
        *instr.add_replica_groups() = group;
      }
    }
    instr.add_dimensions(split_dimension);
    TF_ASSIGN_OR_RETURN(
        XlaOp all_to_all,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
    if (split_dimension == concat_dimension) {
      return all_to_all;
    }
    DimensionVector sizes;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (i != split_dimension) {
        sizes.push_back(operand_shape->dimensions(i));
        continue;
      }
      sizes.push_back(split_count);
      sizes.push_back(operand_shape->dimensions(i) / split_count);
    }
    all_to_all = Reshape(all_to_all, sizes);

    std::vector<int64> permutation;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      int64_t dim_after_reshape = i >= split_dimension ? i + 1 : i;
      if (i == concat_dimension) {
        permutation.push_back(split_dimension);
      }
      permutation.push_back(dim_after_reshape);
    }
    all_to_all = Transpose(all_to_all, permutation);
    return Reshape(all_to_all_shape, all_to_all);
  });
}

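// Worked example of the decomposition above (illustrative only): with an
// operand of shape [8, 10], split_dimension=0, concat_dimension=1, and
// split_count=4, the kAllToAll instruction keeps shape [8, 10]; the first
// Reshape splits the result to [4, 2, 10]; the Transpose (permutation
// {1, 0, 2}) moves the split-count dimension next to the concat dimension,
// giving [2, 4, 10]; and the final Reshape merges it into the inferred
// all-to-all shape [2, 40].
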
XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64_t split_dimension,
                                int64_t concat_dimension, int64_t split_count,
                                absl::Span<const ReplicaGroup> replica_groups,
                                const absl::optional<Layout>& layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    // The HloInstruction for AllToAll currently only handles the data
    // communication: it accepts N already-split parts and scatters them to N
    // cores, and each core gathers the N received parts into a tuple as the
    // output. So here we explicitly split the operand before the HLO
    // AllToAll, and concatenate the tuple elements afterwards.
    //
    // First, run shape inference to make sure the shapes are valid.
    TF_RETURN_IF_ERROR(
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count)
            .status());

    // Split into N parts.
    std::vector<XlaOp> slices;
    slices.reserve(split_count);
    const int64_t block_size =
        operand_shape->dimensions(split_dimension) / split_count;
    for (int i = 0; i < split_count; i++) {
      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
                                  /*limit_index=*/(i + 1) * block_size,
                                  /*stride=*/1, /*dimno=*/split_dimension));
    }

    // Handle data communication.
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
    std::vector<const Shape*> slice_shape_ptrs;
    absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));

    if (layout) {
      TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
      for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
        const int64_t layout_minor_to_major_size =
            layout->minor_to_major().size();
        if (layout_minor_to_major_size != shape.tuple_shapes(i).rank()) {
          return InvalidArgument(
              "Provided layout must be compatible with the operand shape: %s "
              "vs %s",
              layout->ToString(), operand_shape->ToString());
        }
        *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
      }
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = shape.ToProto();

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    TF_ASSIGN_OR_RETURN(
        XlaOp alltoall,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));

    // Concatenate the N received parts.
    std::vector<XlaOp> received;
    received.reserve(split_count);
    for (int i = 0; i < split_count; i++) {
      received.push_back(this->GetTupleElement(alltoall, i));
    }
    return this->ConcatInDim(received, concat_dimension);
  });
}

XlaOp XlaBuilder::CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferCollectivePermuteShape({operand_shape}));
    *instr.mutable_shape() = shape.ToProto();

    for (const auto& pair : source_target_pairs) {
      auto* proto_pair = instr.add_source_target_pairs();
      proto_pair->set_source(pair.first);
      proto_pair->set_target(pair.second);
    }

    return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
                          {operand});
  });
}

XlaOp XlaBuilder::ReplicaId() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
  });
}

XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                   absl::Span<const int64> window_dimensions,
                                   absl::Span<const int64> window_strides,
                                   Padding padding, XlaOp source,
                                   XlaOp init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);

    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions,
        // so we create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          SelectAndScatterInternal(operand, select, window_dimensions,
                                   window_strides, padding_values, source,
                                   init_value, scatter));
      instr.set_custom_call_target("DynamicSelectAndScatterSamePadding");
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                            {operand, source, init_value});
    }
    return SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides, padding_values,
        source, init_value, scatter);
  });
}

StatusOr<HloInstructionProto> XlaBuilder::SelectAndScatterInternal(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  HloInstructionProto instr;

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
  TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
  TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
  TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
                      select.GetProgramShape());
  TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
                      scatter.GetProgramShape());
  TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferSelectAndScatterShape(
                          *operand_shape, select_shape, instr.window(),
                          *source_shape, *init_shape, scatter_shape));
  *instr.mutable_shape() = shape.ToProto();

  AddCalledComputation(select, &instr);
  AddCalledComputation(scatter, &instr);
  return instr;
}

XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(HloInstructionProto instr,
                        SelectAndScatterInternal(
                            operand, select, window_dimensions, window_strides,
                            padding, source, init_value, scatter));

    return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
                          {operand, source, init_value});
  });
}

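// Example (illustrative sketch only; names are hypothetical, with the scalar
// comparison/add computations from client/lib/arithmetic.h). The canonical
// use is the max-pooling gradient: `select` picks each window's maximum and
// `scatter` adds the corresponding source gradient at that position:
//
//   XlaBuilder b("sas_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 8, 8, 1}), "x");
//   XlaOp g = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 4, 4, 1}), "g");
//   SelectAndScatter(x, CreateScalarGeComputation(F32, &b),
//                    /*window_dimensions=*/{1, 2, 2, 1},
//                    /*window_strides=*/{1, 2, 2, 1}, Padding::kValid,
//                    /*source=*/g, ConstantR0<float>(&b, 0.0f),
//                    CreateScalarAddComputation(F32, &b));
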
XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferReducePrecisionShape(
                            *operand_shape, exponent_bits, mantissa_bits));
    return ReducePrecisionInternal(shape, operand, exponent_bits,
                                   mantissa_bits);
  });
}

StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
                                                    XlaOp operand,
                                                    const int exponent_bits,
                                                    const int mantissa_bits) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_exponent_bits(exponent_bits);
  instr.set_mantissa_bits(mantissa_bits);
  return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
                        {operand});
}

Send(XlaOp operand,const ChannelHandle & handle)3231 void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
3232   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3233     // Send HLO takes two operands: a data operand and a token. Generate the
3234     // token to pass into the send.
3235     // TODO(b/80000000): Remove this when clients have been updated to handle
3236     // tokens.
3237     HloInstructionProto token_instr;
3238     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
3239     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
3240                                                     HloOpcode::kAfterAll, {}));
3241 
3242     return SendWithToken(operand, token, handle);
3243   });
3244 }
3245 
SendWithToken(XlaOp operand,XlaOp token,const ChannelHandle & handle)3246 XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
3247                                 const ChannelHandle& handle) {
3248   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3249     if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
3250       return InvalidArgument("Send must use a device-to-device channel");
3251     }
3252 
3253     // Send instruction produces a tuple of {aliased operand, U32 context,
3254     // token}.
3255     HloInstructionProto send_instr;
3256     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
3257     *send_instr.mutable_shape() =
3258         ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
3259                                    ShapeUtil::MakeTokenShape()})
3260             .ToProto();
3261     send_instr.set_channel_id(handle.handle());
3262     TF_ASSIGN_OR_RETURN(XlaOp send,
3263                         AddInstruction(std::move(send_instr), HloOpcode::kSend,
3264                                        {operand, token}));
3265 
3266     HloInstructionProto send_done_instr;
3267     *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
3268     send_done_instr.set_channel_id(handle.handle());
3269     return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
3270                           {send});
3271   });
3272 }
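
// A minimal usage sketch for the token-based send path: create an explicit
// token with CreateToken and thread it through SendWithToken. The channel
// handle setup is illustrative; real handles are normally allocated by the
// client service:
//
//   XlaBuilder b("send_example");
//   ChannelHandle handle;
//   handle.set_handle(1);
//   handle.set_type(ChannelHandle::DEVICE_TO_DEVICE);
//   XlaOp token = CreateToken(&b);
//   XlaOp data = ConstantR0<float>(&b, 42.0f);
//   XlaOp done_token = SendWithToken(data, token, handle);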

XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Recv HLO takes a single token operand. Generate the token to pass into
    // the Recv and RecvDone instructions.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    XlaOp recv = RecvWithToken(token, shape, handle);

    // The RecvDone instruction produces a tuple of the data and a token.
    // Return the XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto recv_data;
    *recv_data.mutable_shape() = shape.ToProto();
    recv_data.set_tuple_index(0);
    return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
                          {recv});
  });
}

XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Recv must use a device-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}
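
// The matching receive side for the send sketch above: RecvWithToken yields a
// {data, token} tuple, so the pieces are extracted with GetTupleElement.
// Names are illustrative:
//
//   XlaBuilder b("recv_example");
//   ChannelHandle handle;  // Same channel as the sender.
//   handle.set_handle(1);
//   handle.set_type(ChannelHandle::DEVICE_TO_DEVICE);
//   XlaOp token = CreateToken(&b);
//   XlaOp recvd = RecvWithToken(token, ShapeUtil::MakeShape(F32, {}), handle);
//   XlaOp data = GetTupleElement(recvd, 0);
//   XlaOp new_token = GetTupleElement(recvd, 1);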

XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
                             const Shape& shape_with_layout,
                             const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Shape passed to SendToHost must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "SendToHost shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    // TODO(b/111544877): Support tuple shapes.
    if (!operand_shape->IsArray()) {
      return InvalidArgument("SendToHost only supports array shapes, shape: %s",
                             ShapeUtil::HumanString(*operand_shape));
    }

    if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
      return InvalidArgument("SendToHost must use a device-to-host channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape_with_layout,
                                   ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    send_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    send_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}
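
// A minimal usage sketch: SendToHost requires a shape with an explicit layout
// and a DEVICE_TO_HOST channel. Names and the handle value are illustrative:
//
//   XlaBuilder b("send_to_host_example");
//   ChannelHandle handle;
//   handle.set_handle(2);
//   handle.set_type(ChannelHandle::DEVICE_TO_HOST);
//   Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
//   XlaOp data = Parameter(&b, 0, shape, "data");
//   XlaOp done = SendToHost(data, CreateToken(&b), shape, handle);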

XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
                               const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Shape passed to RecvFromHost must have a layout");
    }

    // TODO(b/111544877): Support tuple shapes.
    if (!shape.IsArray()) {
      return InvalidArgument(
          "RecvFromHost only supports array shapes, shape: %s",
          ShapeUtil::HumanString(shape));
    }

    if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
      return InvalidArgument("RecvFromHost must use a host-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    recv_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    recv_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}
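
// The corresponding host-to-device sketch: RecvFromHost also requires a
// layout and a HOST_TO_DEVICE channel, and produces a {data, token} tuple.
// Names are illustrative:
//
//   XlaBuilder b("recv_from_host_example");
//   ChannelHandle handle;
//   handle.set_handle(3);
//   handle.set_type(ChannelHandle::HOST_TO_DEVICE);
//   Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
//   XlaOp recvd = RecvFromHost(CreateToken(&b), shape, handle);
//   XlaOp data = GetTupleElement(recvd, 0);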

XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
                                         *operand_shape, dimension));
    // Calling GetDimensionSize on a static dimension returns a constant
    // instruction.
    if (!operand_shape->is_dynamic_dimension(dimension)) {
      return ConstantR0<int32>(this, operand_shape->dimensions(dimension));
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
                          {operand});
  });
}
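
// A minimal usage sketch: on a static dimension GetDimensionSize folds to a
// constant, while on a dynamic dimension it becomes a real
// kGetDimensionSize op evaluated at runtime. The dynamic shape below is
// illustrative:
//
//   XlaBuilder b("get_dimension_size_example");
//   Shape dyn = ShapeUtil::MakeShape(F32, /*dimensions=*/{10},
//                                    /*dynamic_dimensions=*/{true});
//   XlaOp x = Parameter(&b, 0, dyn, "x");
//   XlaOp size = GetDimensionSize(x, 0);  // S32 scalar, known only at runtime.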

XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    Shape shape = *operand_shape;
    shape.set_dynamic_dimension(dimension, false);
    // Setting an op's dynamic dimension to its static size removes the dynamic
    // dimension.
    XlaOp static_size =
        ConstantR0<int32>(this, operand_shape->dimensions(dimension));
    return SetDimensionSizeInternal(shape, operand, static_size, dimension);
  });
}

XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val,
                                   int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));

    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferSetDimensionSizeShape(
                            *operand_shape, *val_shape, dimension));
    return SetDimensionSizeInternal(shape, operand, val, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::SetDimensionSizeInternal(const Shape& shape,
                                                     XlaOp operand, XlaOp val,
                                                     int64_t dimension) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
                      LookUpInstruction(val));
  if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
          HloOpcode::kConstant &&
      shape.is_dynamic_dimension(dimension)) {
    TF_ASSIGN_OR_RETURN(auto constant_size,
                        Literal::CreateFromProto(val_proto->literal(), true));
    if (constant_size.Get<int32>({}) == shape.dimensions(dimension)) {
      return operand;
    }
  }

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.add_dimensions(dimension);
  return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
                        {operand, val});
}
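
// A minimal sketch pairing SetDimensionSize with RemoveDynamicDimension: mark
// a dimension as dynamic with a runtime size, then restore it to its static
// bound. Names are illustrative:
//
//   XlaBuilder b("set_dimension_size_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "x");
//   XlaOp n = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "n");
//   XlaOp dynamic_x = SetDimensionSize(x, n, /*dimension=*/0);
//   XlaOp static_x = RemoveDynamicDimension(dynamic_x, /*dimension=*/0);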

StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
  TF_RETURN_IF_ERROR(first_error_);

  // Verify that the handle is valid.
  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());

  bool is_constant = true;
  absl::flat_hash_set<int64> visited;
  IsConstantVisitor(operand.handle(), /*depth=*/0, &visited, &is_constant);
  return is_constant;
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);
    string op_string =
        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
    return InvalidArgument(
        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
        "  op requested for constant subgraph: %s\n\n"
        "This is an internal error that typically happens when the XLA user "
        "(e.g. TensorFlow) is attempting to determine a value that must be a "
        "compile-time constant (e.g. an array dimension) but it is not capable "
        "of being evaluated at XLA compile time.\n\n"
        "Please file a usability bug with the framework being used (e.g. "
        "TensorFlow).",
        op_string);
  }

  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Build constant subgraph for:\n" << OpToString(root_op);
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                    GetNextId());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() = root->shape();

  // We use std::set to keep the instruction ids in ascending order (which is
  // also a valid dependency order). The related ops will be added to the
  // subgraph in the same order.
  std::set<int64> related_ops;
  absl::flat_hash_map<int64, int64> substitutions;
  absl::flat_hash_set<int64> related_calls;  // Related computations.
  std::queue<int64> worklist;
  worklist.push(root->id());
  related_ops.insert(root->id());

  while (!worklist.empty()) {
    int64_t handle = worklist.front();
    worklist.pop();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(handle));

    auto default_behavior = [&related_ops, &worklist, &related_calls,
                             instr_proto]() {
      for (int64_t id : instr_proto->operand_ids()) {
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
      }
      for (int64_t called_id : instr_proto->called_computation_ids()) {
        related_calls.insert(called_id);
      }
    };

    if (instr_proto->opcode() ==
            HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_proto)) {
      int32_t constant_value = -1;
      HloInstructionProto const_instr;

      if (instr_proto->opcode() ==
          HloOpcodeString(HloOpcode::kGetDimensionSize)) {
        // At this point, BuildConstantSubGraph should never encounter a
        // GetDimensionSize with a dynamic dimension; the IsConstant check at
        // the beginning of this function would have failed.
        //
        // Replace GetDimensionSize with a Constant representing the static
        // bound of the shape.
        int64_t dimension = instr_proto->dimensions(0);
        int64_t operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));

        if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
              dynamic_dimension_is_minus_one)) {
          constant_value =
              static_cast<int32>(operand_proto->shape().dimensions(dimension));
        }
        Literal literal = LiteralUtil::CreateR0(constant_value);
        *const_instr.mutable_literal() = literal.ToProto();
        *const_instr.mutable_shape() = literal.shape().ToProto();
      } else {
        if (instr_proto->literal().shape().element_type() == TUPLE) {
          *const_instr.mutable_literal() =
              // First literal of SetBound contains bounds, second literal
              // contains dynamism indicators.
              instr_proto->literal().tuple_literals(0);
        } else {
          *const_instr.mutable_literal() = instr_proto->literal();
        }

        *const_instr.mutable_shape() = instr_proto->shape();
      }
      *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
      const_instr.set_id(handle);
      *const_instr.mutable_name() =
          GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
      *entry.add_instructions() =
          const_instr;  // Add to the result constant graph.

    } else if (instr_proto->opcode() ==
               HloOpcodeString(HloOpcode::kGetTupleElement)) {
      // Look through GTE(Tuple(..), i).
      TF_ASSIGN_OR_RETURN(
          const HloInstructionProto* maybe_tuple_instr,
          LookUpInstructionByHandle(instr_proto->operand_ids(0)));

      if (maybe_tuple_instr->opcode() == HloOpcodeString(HloOpcode::kTuple)) {
        int64_t id = maybe_tuple_instr->operand_ids(instr_proto->tuple_index());
        // Enqueue any dependencies of `id`.
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
        substitutions[handle] = id;

      } else {
        default_behavior();
      }

    } else {
      default_behavior();
    }
  }

  // Resolve any substitutions for the root id.
  int64_t root_id = root->id();
  auto it = substitutions.find(root_id);
  while (it != substitutions.end()) {
    root_id = it->second;
    it = substitutions.find(root_id);
  }
  entry.set_root_id(root_id);

  // Add related ops to the computation.
  for (int64_t id : related_ops) {
    if (substitutions.find(id) != substitutions.end()) {
      // Skip adding this instruction; we will replace references to it with
      // the substitution instruction's id.
      continue;
    }
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
                        LookUpInstructionByHandle(id));

    if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_src)) {
      continue;
    }
    HloInstructionProto* instr = entry.add_instructions();
    *instr = *instr_src;
    // Replace operands in case we have substitutions mapped.
    instr->clear_operand_ids();
    for (int64_t operand_id : instr_src->operand_ids()) {
      auto it = substitutions.find(operand_id);
      while (it != substitutions.end()) {
        operand_id = it->second;
        it = substitutions.find(operand_id);
      }
      instr->add_operand_ids(operand_id);
    }
    // Ensure that instruction names are unique within the graph.
    const string& new_name =
        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
    instr->set_name(new_name);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& e : embedded_) {
    if (related_calls.find(e.second.id()) != related_calls.end()) {
      *module->add_computations() = e.second;
    }
  }
  *module->add_computations() = std::move(entry);
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Constant computation:\n" << module->DebugString();
  }
  return std::move(computation);
}
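
// A minimal usage sketch for IsConstant/BuildConstantSubGraph: an expression
// that does not depend on any parameter can be lifted into a standalone
// constant computation. Names are illustrative:
//
//   XlaBuilder b("constant_subgraph_example");
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "p");
//   XlaOp c = Add(ConstantR0<float>(&b, 1.0f), ConstantR0<float>(&b, 2.0f));
//   StatusOr<bool> is_constant = b.IsConstant(c);  // true: c ignores p.
//   StatusOr<XlaComputation> subgraph =
//       b.BuildConstantSubGraph(c, /*dynamic_dimension_is_minus_one=*/false);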

std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
    const string& computation_name) {
  auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
  sub_builder->parent_builder_ = this;
  sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
  return sub_builder;
}

/* static */ ConvolutionDimensionNumbers
XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_kernel_output_feature_dimension(
      kConvKernelOutputDimension);
  dimension_numbers.set_kernel_input_feature_dimension(
      kConvKernelInputDimension);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 2);
    dimension_numbers.add_kernel_spatial_dimensions(i + 2);
    dimension_numbers.add_output_spatial_dimensions(i + 2);
  }
  return dimension_numbers;
}
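
// A minimal sketch of what the defaults above expand to for 2D convolution:
// batch and feature in dimensions 0 and 1, spatial dimensions following
// (i.e., the NCHW input / OIHW kernel convention):
//
//   ConvolutionDimensionNumbers dnums =
//       XlaBuilder::CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/2);
//   // dnums.input_batch_dimension()    == 0
//   // dnums.input_feature_dimension()  == 1
//   // dnums.input_spatial_dimensions() == {2, 3}
//   // ... and likewise for the kernel and output dimension numbers.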

/* static */ Status XlaBuilder::Validate(
    const ConvolutionDimensionNumbers& dnum) {
  if (dnum.input_spatial_dimensions_size() < 2) {
    return FailedPrecondition("input spatial dimension < 2: %d",
                              dnum.input_spatial_dimensions_size());
  }
  if (dnum.kernel_spatial_dimensions_size() < 2) {
    return FailedPrecondition("kernel spatial dimension < 2: %d",
                              dnum.kernel_spatial_dimensions_size());
  }
  if (dnum.output_spatial_dimensions_size() < 2) {
    return FailedPrecondition("output spatial dimension < 2: %d",
                              dnum.output_spatial_dimensions_size());
  }

  if (std::set<int64>(
          {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
           dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%d, %d, %d, "
        "%d)",
        dnum.input_batch_dimension(), dnum.input_feature_dimension(),
        dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
  }
  if (std::set<int64>({dnum.kernel_output_feature_dimension(),
                       dnum.kernel_input_feature_dimension(),
                       dnum.kernel_spatial_dimensions(0),
                       dnum.kernel_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%d, %d, %d, "
        "%d)",
        dnum.kernel_output_feature_dimension(),
        dnum.kernel_input_feature_dimension(),
        dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
  }
  if (std::set<int64>({dnum.output_batch_dimension(),
                       dnum.output_feature_dimension(),
                       dnum.output_spatial_dimensions(0),
                       dnum.output_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%d, %d, %d, "
        "%d)",
        dnum.output_batch_dimension(), dnum.output_feature_dimension(),
        dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
  }
  return Status::OK();
}

StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
                                           HloOpcode opcode,
                                           absl::Span<const XlaOp> operands) {
  TF_RETURN_IF_ERROR(first_error_);

  const int64_t handle = GetNextId();
  instr.set_id(handle);
  instr.set_opcode(HloOpcodeString(opcode));
  if (instr.name().empty()) {
    instr.set_name(instr.opcode());
  }
  for (const auto& operand : operands) {
    if (operand.builder_ == nullptr) {
      return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
    }
    if (operand.builder_ != this) {
      return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
                             operand.builder_->name(), this->name());
    }
    instr.add_operand_ids(operand.handle());
  }

  if (one_shot_metadata_.has_value()) {
    *instr.mutable_metadata() = one_shot_metadata_.value();
    one_shot_metadata_.reset();
  } else {
    *instr.mutable_metadata() = metadata_;
  }
  if (sharding_) {
    *instr.mutable_sharding() = *sharding_;
  }
  *instr.mutable_frontend_attributes() = frontend_attributes_;

  handle_to_index_[handle] = instructions_.size();
  instructions_.push_back(std::move(instr));
  instruction_shapes_.push_back(
      absl::make_unique<Shape>(instructions_.back().shape()));

  XlaOp op(handle, this);
  return op;
}

StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                           absl::Span<const XlaOp> operands) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), opcode, operands);
}

void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
                                      HloInstructionProto* instr) {
  absl::flat_hash_map<int64, int64> remapped_ids;
  std::vector<HloComputationProto> imported_computations;
  imported_computations.reserve(computation.proto().computations_size());
  // First, import the computations, remapping IDs as we go and capturing the
  // old->new mappings in remapped_ids.
  for (const HloComputationProto& e : computation.proto().computations()) {
    HloComputationProto new_computation(e);
    int64_t computation_id = GetNextId();
    remapped_ids[new_computation.id()] = computation_id;
    SetProtoIdAndName(&new_computation,
                      GetBaseName(new_computation.name(), kNameSeparator),
                      kNameSeparator, computation_id);
    for (auto& instruction : *new_computation.mutable_instructions()) {
      int64_t instruction_id = GetNextId();
      remapped_ids[instruction.id()] = instruction_id;
      SetProtoIdAndName(&instruction,
                        GetBaseName(instruction.name(), kNameSeparator),
                        kNameSeparator, instruction_id);
    }
    new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));

    imported_computations.push_back(std::move(new_computation));
  }
  // Once we have imported all the computations and captured all the ID
  // mappings, we go back and fix up the IDs in the imported computations.
  instr->add_called_computation_ids(
      remapped_ids.at(computation.proto().entry_computation_id()));
  for (auto& imported_computation : imported_computations) {
    for (auto& instruction : *imported_computation.mutable_instructions()) {
      for (auto& operand_id : *instruction.mutable_operand_ids()) {
        operand_id = remapped_ids.at(operand_id);
      }
      for (auto& control_predecessor_id :
           *instruction.mutable_control_predecessor_ids()) {
        control_predecessor_id = remapped_ids.at(control_predecessor_id);
      }
      for (auto& called_computation_id :
           *instruction.mutable_called_computation_ids()) {
        called_computation_id = remapped_ids.at(called_computation_id);
      }
    }

    int64_t computation_id = imported_computation.id();
    for (int64_t i = 0; i < imported_computation.instructions_size(); ++i) {
      ImportedInstruction imported_instruction;
      imported_instruction.computation_id = computation_id;
      imported_instruction.instruction_index = i;
      handle_to_imported_index_.insert(
          {imported_computation.instructions(i).id(), imported_instruction});
    }
    embedded_.insert({computation_id, std::move(imported_computation)});
  }
}

StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
    const XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  return LookUpInstructionInternal<const HloInstructionProto*>(op);
}

StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
    int64_t handle) const {
  return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
}

StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
    const XlaOp op) {
  TF_RETURN_IF_ERROR(first_error_);
  return LookUpInstructionInternal<HloInstructionProto*>(op);
}

StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
    int64_t handle) {
  return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
}

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const string& name) {
  std::vector<bool> empty_bools;
  return Parameter(builder, parameter_number, shape, name, empty_bools);
}

XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const string& name,
                const std::vector<bool>& replicated_at_leaf_buffers) {
  return builder->Parameter(parameter_number, shape, name,
                            replicated_at_leaf_buffers);
}

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
  return builder->ConstantLiteral(literal);
}

XlaOp Broadcast(const XlaOp operand, absl::Span<const int64> broadcast_sizes) {
  return operand.builder()->Broadcast(operand, broadcast_sizes);
}

XlaOp BroadcastInDim(const XlaOp operand,
                     const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions) {
  return operand.builder()->BroadcastInDim(operand, out_dim_size,
                                           broadcast_dimensions);
}

XlaOp Copy(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
}

XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
          const PaddingConfig& padding_config) {
  return operand.builder()->Pad(operand, padding_value, padding_config);
}

XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
               int64_t pad_lo, int64_t pad_hi) {
  return operand.builder()->PadInDim(operand, padding_value, dimno, pad_lo,
                                     pad_hi);
}

XlaOp Reshape(const XlaOp operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes) {
  return operand.builder()->Reshape(operand, dimensions, new_sizes);
}

XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
  return operand.builder()->Reshape(operand, new_sizes);
}

XlaOp Reshape(const Shape& shape, XlaOp operand) {
  return operand.builder()->Reshape(shape, operand);
}

XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic) {
  return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
                                           dims_are_dynamic);
}

XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64> new_sizes,
                                   int64_t inferred_dimension) {
  return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
}

XlaOp Collapse(const XlaOp operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Collapse(operand, dimensions);
}

XlaOp Slice(const XlaOp operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides) {
  return operand.builder()->Slice(operand, start_indices, limit_indices,
                                  strides);
}

XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index,
                 int64_t stride, int64_t dimno) {
  return operand.builder()->SliceInDim(operand, start_index, limit_index,
                                       stride, dimno);
}

XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes) {
  return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}

XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
                         absl::Span<const XlaOp> start_indices) {
  return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}

XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64_t dimension) {
  return builder->ConcatInDim(operands, dimension);
}

void Trace(const string& tag, const XlaOp operand) {
  return operand.builder()->Trace(tag, operand);
}

XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
  return pred.builder()->Select(pred, on_true, on_false);
}

XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
  return builder->Tuple(elements);
}

XlaOp GetTupleElement(const XlaOp tuple_data, int64_t index) {
  return tuple_data.builder()->GetTupleElement(tuple_data, index);
}

XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}

static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
                               absl::Span<const int64> broadcast_dimensions,
                               ComparisonDirection comparison_direction) {
  auto b = lhs.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
    auto operand_element_type = operand_shape.element_type();
    auto compare_type =
        primitive_util::IsFloatingPointType(operand_element_type)
            ? Comparison::Type::kFloatTotalOrder
            : Comparison::DefaultComparisonType(operand_element_type);
    return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
                   compare_type);
  });
}

XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kEq);
}

XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
}

XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kNe);
}

XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
}

XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kGe);
}

XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
}

XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kGt);
}

XlaOp Le(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
}

XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kLe);
}

XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
}

XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
                           ComparisonDirection::kLt);
}
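
// A minimal sketch contrasting the default comparisons with the total-order
// variants on floats: under kFloatTotalOrder a (positive) NaN compares
// greater than every finite value and +inf, instead of making every
// comparison false. Names are illustrative:
//
//   XlaBuilder b("total_order_example");
//   XlaOp nan =
//       ConstantR0<float>(&b, std::numeric_limits<float>::quiet_NaN());
//   XlaOp one = ConstantR0<float>(&b, 1.0f);
//   XlaOp p1 = Gt(nan, one, /*broadcast_dimensions=*/{});            // false
//   XlaOp p2 = GtTotalOrder(nan, one, /*broadcast_dimensions=*/{});  // true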

XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction) {
  return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                 broadcast_dimensions, direction);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type) {
  return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                 broadcast_dimensions, direction, compare_type);
}

XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
  return Compare(lhs, rhs, {}, direction);
}

XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
          const PrecisionConfig* precision_config,
          absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
}

XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config,
                 absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
                                   precision_config, preferred_element_type);
}

XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
           absl::Span<const int64> window_strides, Padding padding,
           int64_t feature_group_count, int64_t batch_group_count,
           const PrecisionConfig* precision_config,
           absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
                             feature_group_count, batch_group_count,
                             precision_config, preferred_element_type);
}

XlaOp ConvWithGeneralPadding(
    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvWithGeneralPadding(
      lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
      precision_config, preferred_element_type);
}

XlaOp ConvWithGeneralDimensions(
    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
      batch_group_count, precision_config, preferred_element_type);
}

XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
                  absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64_t feature_group_count, int64_t batch_group_count,
                  const PrecisionConfig* precision_config,
                  absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvGeneral(
      lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
      batch_group_count, precision_config, preferred_element_type);
}

XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64_t feature_group_count, int64_t batch_group_count,
                         const PrecisionConfig* precision_config,
                         absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->ConvGeneralDilated(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, preferred_element_type);
}

XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvInputGrad(
      input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return activations.builder()->DynamicConvKernelGrad(
      activations, gradients, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64_t feature_group_count, int64_t batch_group_count,
                         const PrecisionConfig* precision_config,
                         PaddingType padding_type,
                         absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvForward(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp Fft(const XlaOp operand, FftType fft_type,
          absl::Span<const int64> fft_length) {
  return operand.builder()->Fft(operand, fft_type, fft_length);
}

XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
    xla::TriangularSolveOptions options;
    options.set_left_side(left_side);
    options.set_lower(lower);
    options.set_unit_diagonal(unit_diagonal);
    options.set_transpose_a(transpose_a);
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
                                         *a_shape, *b_shape, options));
    return builder->TriangularSolveInternal(shape, a, b, std::move(options));
  });
}
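
// A minimal usage sketch: solve A X = B for X with A lower triangular. Shapes
// are illustrative; the trailing two dimensions of `a` must be square:
//
//   XlaBuilder b("triangular_solve_example");
//   XlaOp a = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 4}), "a");
//   XlaOp rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4, 2}), "rhs");
//   XlaOp x = TriangularSolve(a, rhs, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);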

XlaOp Cholesky(XlaOp a, bool lower) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferCholeskyShape(*a_shape));
    return builder->CholeskyInternal(shape, a, lower);
  });
}
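
// A minimal usage sketch: factor a symmetric positive-definite matrix into
// its lower-triangular Cholesky factor. Names are illustrative:
//
//   XlaBuilder b("cholesky_example");
//   XlaOp a = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 4}), "a");
//   XlaOp l = Cholesky(a, /*lower=*/true);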
4260 
XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
  return builder->Infeed(shape, config);
}

void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
             const string& outfeed_config) {
  return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
}

XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands) {
  return builder->Call(computation, operands);
}

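// All CustomCall variants below funnel into XlaBuilder::CustomCall and
// differ only in which optional pieces (a called computation, operand
// layouts, or convolution metadata) they supply. A minimal usage sketch,
// assuming a host-registered target "my_custom_op" (hypothetical name) and
// relying on the header's default arguments:
//
//   XlaBuilder b("example");
//   XlaOp arg = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {16}), "arg");
//   XlaOp out = CustomCall(&b, "my_custom_op", {arg},
//                          ShapeUtil::MakeShape(F32, {16}));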
XlaOp CustomCall(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             /*operand_shapes_with_layout=*/absl::nullopt,
                             has_side_effect, output_operand_aliasing, literal,
                             /*window=*/absl::nullopt, /*dnums=*/absl::nullopt,
                             schedule, api_version);
}

XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(
      call_target_name, operands, computation, shape, opaque,
      /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect,
      output_operand_aliasing, literal, schedule, api_version);
}

XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return builder->CustomCall(
      call_target_name, operands, shape, opaque, operand_shapes_with_layout,
      has_side_effect, output_operand_aliasing, literal,
      /*window=*/absl::nullopt, /*dnums=*/absl::nullopt, schedule, api_version);
}

XlaOp CustomCallWithConvDnums(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  absl::optional<absl::Span<const Shape>> maybe_operand_shapes;
  if (!operand_shapes_with_layout.empty()) {
    maybe_operand_shapes = operand_shapes_with_layout;
  }
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             maybe_operand_shapes, has_side_effect,
                             output_operand_aliasing, literal, window, dnums,
                             schedule, api_version);
}

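// Complex assembles a complex-valued array from real and imaginary parts,
// and Conj negates the imaginary part. The elementwise arithmetic, logical,
// and shift wrappers that follow simply tag the matching HloOpcode and
// leave broadcasting to XlaBuilder::BinaryOp.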
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Conj(const XlaOp operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}

XlaOp Add(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Div(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Max(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Min(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp And(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Or(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Not(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
}

XlaOp PopulationCount(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
}

XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
                absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
                        absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
                                 broadcast_dimensions);
}

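// Reduces "operand" along the given dimensions using "computation" as the
// reduction operator. A minimal sketch that sums an 8x4 array over dimension
// 0, assuming the CreateScalarAddComputation helper from
// xla/client/lib/arithmetic.h:
//
//   XlaBuilder b("reduce_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "x");
//   XlaComputation add = CreateScalarAddComputation(F32, &b);
//   XlaOp sums = Reduce(x, ConstantR0<float>(&b, 0.0f), add, {0});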
XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return operand.builder()->Reduce(operand, init_value, computation,
                                   dimensions_to_reduce);
}

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return builder->Reduce(operands, init_values, computation,
                         dimensions_to_reduce);
}

XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
                const XlaComputation& computation) {
  return operand.builder()->ReduceAll(operand, init_value, computation);
}

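// ReduceWindow applies the reduction over sliding windows of the operand
// (the core of pooling ops); the general-padding overloads also take
// explicit per-dimension padding, base dilation, and window dilation.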
XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  return operand.builder()->ReduceWindow(operand, init_value, computation,
                                         window_dimensions, window_strides,
                                         padding);
}

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindow(operands, init_values, computation,
                                             window_dimensions, window_strides,
                                             padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp operand, const XlaOp init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return operand.builder()->ReduceWindowWithGeneralPadding(
      absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
      window_dimensions, window_strides, base_dilations, window_dilations,
      padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindowWithGeneralPadding(
      operands, init_values, computation, window_dimensions, window_strides,
      base_dilations, window_dilations, padding);
}

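// Cross-replica collectives (gather, reduce, scatter, all-to-all, permute).
// Each wrapper forwards to the corresponding XlaBuilder method, which
// performs shape inference for the collective result.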
XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
                int64_t shard_count,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Layout>& layout,
                const absl::optional<bool> use_global_device_ids) {
  return operand.builder()->AllGather(operand, all_gather_dimension,
                                      shard_count, replica_groups, channel_id,
                                      layout, use_global_device_ids);
}

XlaOp CrossReplicaSum(const XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups) {
  return operand.builder()->CrossReplicaSum(operand, replica_groups);
}

XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Shape>& shape_with_layout) {
  return operand.builder()->AllReduce(operand, computation, replica_groups,
                                      channel_id, shape_with_layout);
}

XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
                    int64_t scatter_dimension, int64_t shard_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const absl::optional<ChannelHandle>& channel_id,
                    const absl::optional<Layout>& layout,
                    const absl::optional<bool> use_global_device_ids) {
  return operand.builder()->ReduceScatter(
      operand, computation, scatter_dimension, shard_count, replica_groups,
      channel_id, layout, use_global_device_ids);
}

XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
               int64_t concat_dimension, int64_t split_count,
               absl::Span<const ReplicaGroup> replica_groups,
               const absl::optional<Layout>& layout) {
  return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
                                     split_count, replica_groups, layout);
}

XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension,
                    int64_t concat_dimension, int64_t split_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const absl::optional<Layout>& layout) {
  return operand.builder()->AllToAllTuple(operand, split_dimension,
                                          concat_dimension, split_count,
                                          replica_groups, layout);
}

XlaOp CollectivePermute(
    const XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return operand.builder()->CollectivePermute(operand, source_target_pairs);
}

XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }

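// SelectAndScatter is the backward counterpart of windowed reductions such
// as max-pooling: "select" picks one element per window of the operand, and
// the matching element of "source" is combined into the result at that
// position via "scatter", starting from "init_value".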
XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp source, const XlaOp init_value,
                       const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
                                             window_strides, padding, source,
                                             init_value, scatter);
}

XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp source,
    const XlaOp init_value, const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      init_value, scatter);
}

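// Elementwise math wrappers: most map one-to-one onto a single HloOpcode
// via UnaryOp; Atan2 and Pow are the binary exceptions and go through
// BinaryOp with the usual broadcast semantics.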
XlaOp Abs(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Exp(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
}
XlaOp Expm1(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
}
XlaOp Floor(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
}
XlaOp Ceil(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
}
XlaOp Round(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
XlaOp Log(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
}
XlaOp Log1p(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
}
XlaOp Logistic(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand);
}
XlaOp Sign(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
}
XlaOp Clz(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
}
XlaOp Cos(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
}
XlaOp Sin(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
}
XlaOp Tanh(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
}
XlaOp Real(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
}
XlaOp Imag(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
}
XlaOp Sqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
}
XlaOp Cbrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand);
}
XlaOp Rsqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
}

XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp IsFinite(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}

XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->ConvertElementType(operand, new_element_type);
}

XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->BitcastConvertType(operand, new_element_type);
}

XlaOp Neg(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
}

XlaOp Transpose(const XlaOp operand, absl::Span<const int64> permutation) {
  return operand.builder()->Transpose(operand, permutation);
}

XlaOp Rev(const XlaOp operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Rev(operand, dimensions);
}

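// Sorts one or more operands along "dimension" using "comparator", which is
// called with a pair of scalars per operand. A minimal ascending-order f32
// comparator sketch (names are illustrative):
//
//   XlaBuilder cb("cmp");
//   XlaOp a = Parameter(&cb, 0, ShapeUtil::MakeShape(F32, {}), "a");
//   XlaOp b = Parameter(&cb, 1, ShapeUtil::MakeShape(F32, {}), "b");
//   Lt(a, b);
//   XlaComputation comparator = cb.Build().ValueOrDie();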
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension, bool is_stable) {
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}

XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
  return min.builder()->Clamp(min, operand, max);
}

XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions, static_operands);
}

XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  return initial_state.builder()->RngBitGenerator(algorithm, initial_state,
                                                  shape);
}

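// Structured control flow: While re-applies "body" to the loop state while
// "condition" holds, and Conditional dispatches to one of several branch
// computations. A minimal counting-loop sketch (names are illustrative):
//
//   XlaBuilder cond_b("cond"), body_b("body"), b("loop");
//   const Shape s32 = ShapeUtil::MakeShape(S32, {});
//   Lt(Parameter(&cond_b, 0, s32, "i"), ConstantR0<int32>(&cond_b, 10));
//   Add(Parameter(&body_b, 0, s32, "i"), ConstantR0<int32>(&body_b, 1));
//   XlaOp ten = While(cond_b.Build().ValueOrDie(),
//                     body_b.Build().ValueOrDie(), ConstantR0<int32>(&b, 0));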
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  return init.builder()->While(condition, body, init);
}

XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}

XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
                      const int mantissa_bits) {
  return operand.builder()->ReducePrecision(operand, exponent_bits,
                                            mantissa_bits);
}

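// Gather and Scatter move slices between arrays according to the supplied
// dimension numbers; "indices_are_sorted" and "unique_indices" are promises
// from the caller that enable backend optimizations.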
XlaOp Gather(const XlaOp input, const XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  return input.builder()->Gather(input, start_indices, dimension_numbers,
                                 slice_sizes, indices_are_sorted);
}

XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
              const XlaOp updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return input.builder()->Scatter(input, scatter_indices, updates,
                                  update_computation, dimension_numbers,
                                  indices_are_sorted, unique_indices);
}

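// Point-to-point channel transfers. The plain Send/Recv forms manage tokens
// implicitly, while the *WithToken and host variants thread an explicit
// token operand to order side effects.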
void Send(const XlaOp operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}

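// Batch-normalization trio: training computes per-feature statistics on the
// fly, inference consumes precomputed mean and variance, and grad returns
// gradients with respect to the operand, scale, and offset.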
XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64_t feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64_t feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64_t feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
                                          grad_output, epsilon, feature_index);
}

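// Iota fills an array with values increasing along the chosen dimension;
// the first overload is shorthand for a rank-1 shape of the given size.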
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}

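// Dynamic-dimension utilities: read a dimension's runtime size, bind it to
// a runtime value, or mark a dynamic dimension as static again.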
XlaOp GetDimensionSize(const XlaOp operand, int64_t dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val,
                       int64_t dimension) {
  return operand.builder()->SetDimensionSize(operand, val, dimension);
}

XlaOp RemoveDynamicDimension(const XlaOp operand, int64_t dimension) {
  return operand.builder()->RemoveDynamicDimension(operand, dimension);
}

}  // namespace xla