/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/xla_builder.h"

#include <functional>
#include <numeric>
#include <queue>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

using absl::StrCat;

namespace {

static const char kNameSeparator = '.';

// Retrieves the base name of an instruction's or computation's fully
// qualified name, using `separator` as the boundary between the base name
// and the numeric identifier.
string GetBaseName(const string& name, char separator) {
  auto pos = name.rfind(separator);
  CHECK_NE(pos, string::npos) << name;
  return name.substr(0, pos);
}

// Generates a fully qualified computation/instruction name.
string GetFullName(const string& base_name, char separator, int64 id) {
  const char separator_str[] = {separator, '\0'};
  return StrCat(base_name, separator_str, id);
}
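
// For illustration (not in the original source): with kNameSeparator ('.'),
// GetFullName("add", '.', 42) yields "add.42", and GetBaseName("add.42", '.')
// recovers "add".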

// Common function to standardize setting name and IDs on computation and
// instruction proto entities.
template <typename T>
void SetProtoIdAndName(T* entry, const string& base_name, char separator,
                       int64 id) {
  entry->set_id(id);
  entry->set_name(GetFullName(base_name, separator, id));
}

ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
  return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
}

void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
                              const Shape& shape, bool pred) {
  Literal literal = LiteralUtil::CreateR0(pred);
  Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
  *instr->mutable_shape() = shape.ToProto();
  *instr->mutable_literal() = literal_broadcast.ToProto();
  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
}

// Copies `original_reducer` into a new computation proto with `reducer_id` as
// its new id. If `rewrite_into_pred` is true, the instructions in the reducer
// are rewritten into predicate form.
HloComputationProto CopyReducer(int64 reducer_id,
                                HloComputationProto* original_reducer,
                                bool rewrite_into_pred, int64* global_id) {
  HloComputationProto reducer;
  SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
  std::vector<int64> operands_id;
  for (auto& inst : original_reducer->instructions()) {
    // Copy params.
    if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
        HloOpcode::kParameter) {
      HloInstructionProto* new_param = reducer.add_instructions();
      *new_param = inst;
      new_param->set_id((*global_id)++);
      *new_param->mutable_name() =
          GetFullName(inst.name(), '.', new_param->id());
      if (rewrite_into_pred) {
        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
      }
      operands_id.push_back(new_param->id());
    }
    if (inst.id() == original_reducer->root_id()) {
      HloInstructionProto* new_root = reducer.add_instructions();
      *new_root = inst;
      new_root->set_id((*global_id)++);
      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
      if (rewrite_into_pred) {
        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
      }
      new_root->clear_operand_ids();
      for (int64 operand_id : operands_id) {
        new_root->add_operand_ids(operand_id);
      }
      reducer.set_root_id(new_root->id());
    }
  }
  return reducer;
}
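
// For illustration (not in the original source): when rewrite_into_pred is
// set, a reducer whose root computes add(p0, p1) over f32 is copied into a
// computation whose root computes or(p0, p1) over PRED, which is the form a
// reduce-or needs.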

bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
  HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
  if (opcode == HloOpcode::kCustomCall &&
      instr_proto->custom_call_target() == "SetBound") {
    return true;
  }
  return false;
}

}  // namespace

namespace internal {

XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
                                    absl::Span<const XlaOp> operands,
                                    absl::string_view fusion_kind,
                                    const XlaComputation& fused_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_fusion_kind(std::string(fusion_kind));
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(auto program_shape,
                        fused_computation.GetProgramShape());
    *instr.mutable_shape() = program_shape.result().ToProto();
    builder->AddCalledComputation(fused_computation, &instr);
    return builder->AddInstruction(std::move(instr), HloOpcode::kFusion,
                                   operands);
  });
}

XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
                                     const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
                                   {operand});
  });
}

HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
  return &op.builder()
              ->instructions_[op.builder()->handle_to_index_[op.handle_]];
}

}  // namespace internal

XlaOp operator-(XlaOp x) { return Neg(x); }
XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }

XlaOp operator~(XlaOp x) { return Not(x); }
XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }

XlaOp operator>>(XlaOp x, XlaOp y) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, builder->GetShapePtr(x));
    if (!ShapeUtil::ElementIsIntegral(*shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    if (ShapeUtil::ElementIsSigned(*shape)) {
      return ShiftRightArithmetic(x, y);
    } else {
      return ShiftRightLogical(x, y);
    }
  });
}
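
// For illustration (not in the original source): for integral element types,
// x >> y lowers to ShiftRightArithmetic when the type is signed (e.g. S32)
// and to ShiftRightLogical when it is unsigned (e.g. U32); non-integral
// element types are rejected with InvalidArgument.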

StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  auto it = handle_to_index_.find(op.handle());
  if (it == handle_to_index_.end()) {
    return InvalidArgument("No XlaOp with handle %d", op.handle());
  }
  return instruction_shapes_.at(it->second).get();
}

StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
  TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
  return *shape;
}

StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    absl::Span<const XlaOp> operands) const {
  std::vector<Shape> operand_shapes;
  for (XlaOp operand : operands) {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    operand_shapes.push_back(*shape);
  }
  return operand_shapes;
}

XlaBuilder::XlaBuilder(const string& computation_name)
    : name_(computation_name) {}

XlaBuilder::~XlaBuilder() {}

XlaOp XlaBuilder::ReportError(const Status& error) {
  CHECK(!error.ok());
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
  return XlaOp(this);
}

XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
  if (!first_error_.ok()) {
    return XlaOp(this);
  }
  if (!op.ok()) {
    return ReportError(op.status());
  }
  return op.ValueOrDie();
}

XlaOp XlaBuilder::ReportErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  return ReportErrorOrReturn(op_creator());
}
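
// For illustration (not in the original source): this deferred-error scheme
// lets clients chain ops without checking each one. The first failure is
// recorded in first_error_, subsequent ops return an invalid XlaOp, and the
// error finally surfaces (with a backtrace) from GetCurrentStatus() when
// Build() is called.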

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
                      LookUpInstructionByHandle(root_id));

  ProgramShape program_shape;

  *program_shape.mutable_result() = Shape(root_proto->shape());

  // Check that the parameter numbers are contiguous from 0, and add parameter
  // shapes and names to the program shape.
  const int64 param_count = parameter_numbers_.size();
  for (int64 i = 0; i < param_count; i++) {
    program_shape.add_parameters();
    program_shape.add_parameter_names();
  }
  for (const HloInstructionProto& instr : instructions_) {
    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
    // to verify contiguity, we just need to verify that every parameter is in
    // the right range.
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      const int64 index = instr.parameter_number();
      TF_RET_CHECK(index >= 0 && index < param_count)
          << "invalid parameter number: " << index;
      *program_shape.mutable_parameters(index) = Shape(instr.shape());
      *program_shape.mutable_parameter_names(index) = instr.name();
    }
  }
  return program_shape;
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  TF_RET_CHECK(!instructions_.empty());
  return GetProgramShape(instructions_.back().id());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return GetProgramShape(root.handle());
}

void XlaBuilder::IsConstantVisitor(const int64 op_handle,
                                   absl::flat_hash_set<int64>* visited,
                                   bool* is_constant) const {
  if (visited->contains(op_handle) || !*is_constant) {
    return;
  }

  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
  switch (opcode) {
    default:
      for (const int64 operand_id : instr.operand_ids()) {
        IsConstantVisitor(operand_id, visited, is_constant);
      }
      // TODO(b/32495713): We aren't checking the called computations.
      break;

    case HloOpcode::kGetDimensionSize:
      // GetDimensionSize is always considered constant in XLA -- if a dynamic
      // dimension is present, -1 is returned.
      break;
    // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant.  We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
    case HloOpcode::kCustomCall:
      if (instr.custom_call_target() == "SetBound") {
        // SetBound is considered constant -- the bound is used as the value.
        break;
      }
      TF_FALLTHROUGH_INTENDED;
    case HloOpcode::kWhile:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kScatter:
      // TODO(b/32495713): We aren't checking the embedded computation in
      // Scatter.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter:
      *is_constant = false;
      break;
  }
  if (!*is_constant) {
    VLOG(1) << "Non-constant: " << instr.name();
  }
  visited->insert(op_handle);
}

Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
                                     ShapeIndex dynamic_size_param_index,
                                     int64 target_param_num,
                                     ShapeIndex target_param_index,
                                     int64 target_dim_num) {
  bool param_exists = false;
  for (size_t index = 0; index < instructions_.size(); ++index) {
    HloInstructionProto& instr = instructions_[index];
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
        instr.parameter_number() == target_param_num) {
      param_exists = true;
      Shape param_shape(instr.shape());
      Shape* param_shape_ptr = &param_shape;
      for (int64 index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      *instr.mutable_shape() = param_shape.ToProto();
      instruction_shapes_[index] =
          absl::make_unique<Shape>(std::move(param_shape));
    }
  }
  if (!param_exists) {
    return InvalidArgument(
        "Asked to mark parameter %lld as dynamically sized, but the parameter "
        "doesn't exist",
        target_param_num);
  }

  TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
      DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
                                                dynamic_size_param_index},
      DynamicParameterBinding::DynamicDimension{
          target_param_num, target_param_index, target_dim_num}));
  return Status::OK();
}

Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
                                                   std::string attribute,
                                                   std::string value) {
  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
  (*frontend_attributes->mutable_map())[attribute] = std::move(value);
  return Status::OK();
}

XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->ReportError(
        AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
    return {};
  }
  return build_status.ConsumeValueOrDie();
}

Status XlaBuilder::GetCurrentStatus() const {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }
  return Status::OK();
}

StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());
  return Build(instructions_.back().id(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
                                           bool remove_dynamic_dimensions) {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return Build(root.handle(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
                                           bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());

  // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet,
  // so remove all dynamic dimensions before building the XLA program, until
  // the backend supports them.
  if (remove_dynamic_dimensions) {
    std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
      if (shape->tuple_shapes_size() != 0) {
        for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) {
          remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
        }
      }
      for (int64 i = 0; i < shape->dimensions_size(); ++i) {
        shape->set_dynamic_dimension(i, false);
      }
    };
    for (size_t index = 0; index < instructions_.size(); ++index) {
      remove_dynamic_dimension(instruction_shapes_[index].get());
      *instructions_[index].mutable_shape() =
          instruction_shapes_[index]->ToProto();
    }
  }

  HloComputationProto entry;
  SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
  *entry.mutable_program_shape() = program_shape.ToProto();
  entry.set_root_id(root_id);

  for (auto& instruction : instructions_) {
    // Ensures that the instruction names are unique among the whole graph.
    instruction.set_name(
        GetFullName(instruction.name(), kNameSeparator, instruction.id()));
    entry.add_instructions()->Swap(&instruction);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = entry.program_shape();
  for (auto& e : embedded_) {
    module->add_computations()->Swap(&e.second);
  }
  module->add_computations()->Swap(&entry);
  if (!input_output_aliases_.empty()) {
    TF_RETURN_IF_ERROR(
        PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
  }
  *(module->mutable_dynamic_parameter_binding()) =
      dynamic_parameter_binding_.ToProto();

  // Clear data held by this builder.
  this->instructions_.clear();
  this->instruction_shapes_.clear();
  this->handle_to_index_.clear();
  this->embedded_.clear();
  this->parameter_numbers_.clear();

  return std::move(computation);
}
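
// A minimal usage sketch (illustrative only; not part of the original
// source):
//   XlaBuilder b("add_one");
//   XlaOp x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x", {});
//   x + b.ConstantLiteral(LiteralUtil::CreateR0<float>(1.0f));
//   StatusOr<XlaComputation> result =
//       b.Build(/*remove_dynamic_dimensions=*/false);
// Build() uses the last added instruction as the root unless a root is given.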

/* static */ Status XlaBuilder::PopulateInputOutputAlias(
    HloModuleProto* module, const ProgramShape& program_shape,
    const std::vector<InputOutputAlias>& input_output_aliases) {
  HloInputOutputAliasConfig config(program_shape.result());
  for (auto& alias : input_output_aliases) {
    // The HloInputOutputAliasConfig does not do parameter validation, as it
    // only carries the result shape. Maybe it should be constructed with a
    // ProgramShape to allow full validation. We will still get an error when
    // trying to compile the HLO module, but it would be better to have
    // validation at this stage.
    if (alias.param_number >= program_shape.parameters_size()) {
      return InvalidArgument("Invalid parameter number %ld (total %ld)",
                             alias.param_number,
                             program_shape.parameters_size());
    }
    const Shape& parameter_shape = program_shape.parameters(alias.param_number);
    if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
      return InvalidArgument("Invalid parameter %ld index: %s",
                             alias.param_number,
                             alias.param_index.ToString().c_str());
    }
    TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
                                         alias.param_index, alias.kind));
  }
  *module->mutable_input_output_alias() = config.ToProto();
  return Status::OK();
}

StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64 dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }

  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                 XlaOp operand) {
  TF_RETURN_IF_ERROR(first_error_);

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

  CHECK(ShapeUtil::IsScalar(*operand_shape) ||
        operand_shape->rank() == output_shape.rank());
  Shape broadcast_shape =
      ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());

  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(*operand_shape)) {
    return InDimBroadcast(broadcast_shape, operand, {});
  }

  // Do explicit broadcast for degenerate broadcast.
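  // For illustration (not in the original source): broadcasting an operand of
  // shape [2,1,4] to output [2,3,4] reshapes the operand to [2,4] (dropping
  // the size-1 dimension) and then emits an in-dim broadcast with
  // dimensions {0,2}.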
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  for (int i = 0; i < operand_shape->rank(); i++) {
    if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_shape->dimensions(i));
    } else {
      TF_RET_CHECK(operand_shape->dimensions(i) == 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand shape: "
          << *operand_shape << "; output_shape: " << output_shape;
    }
  }

  Shape reshaped_shape =
      ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);

  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);

  for (auto& unmodified : unmodified_dims) {
    if (operand_shape->is_dynamic_dimension(unmodified.first)) {
      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
    }
  }

  // Eliminate the size-1 dimensions.
  TF_ASSIGN_OR_RETURN(
      XlaOp reshaped_operand,
      ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1));
  // Broadcast the reshaped operand up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
}

XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
    return AddOpWithShape(unop, shape, {operand});
  });
}

XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions,
                           absl::optional<ComparisonDirection> direction,
                           absl::optional<Comparison::Type> type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBinaryOpShape(
                         binop, *lhs_shape, *rhs_shape, broadcast_dimensions));

    const int64 lhs_rank = lhs_shape->rank();
    const int64 rhs_rank = rhs_shape->rank();

    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
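    // If the ranks differ and broadcast_dimensions were given, first
    // broadcast the lower-rank operand into the inferred result rank.
    // Illustrative example (not in the original source): lhs of shape [2,3]
    // with rhs of shape [3] and broadcast_dimensions {1} broadcasts rhs
    // to [2,3].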
    if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
      const bool should_broadcast_lhs = lhs_rank < rhs_rank;
      XlaOp from = should_broadcast_lhs ? lhs : rhs;
      const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;

      std::vector<int64> to_size;
      std::vector<bool> to_size_is_dynamic;
      for (int i = 0; i < shape.rank(); i++) {
        to_size.push_back(shape.dimensions(i));
        to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
      }
      for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
        int64 to_dim = broadcast_dimensions[from_dim];
        to_size[to_dim] = from_shape.dimensions(from_dim);
        to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
      }

      const Shape& broadcasted_shape = ShapeUtil::MakeShape(
          from_shape.element_type(), to_size, to_size_is_dynamic);
      TF_ASSIGN_OR_RETURN(
          XlaOp broadcasted_operand,
          InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));

      updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
      updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
    }

    TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
                        GetShapePtr(updated_lhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_lhs,
                          AddBroadcastSequence(shape, updated_lhs));
    }
    TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
                        GetShapePtr(updated_rhs));
    if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_rhs,
                          AddBroadcastSequence(shape, updated_rhs));
    }

    if (binop == HloOpcode::kCompare) {
      if (!direction.has_value()) {
        return InvalidArgument(
            "kCompare expects a ComparisonDirection, but none provided.");
      }
      if (type == absl::nullopt) {
        return Compare(shape, updated_lhs, updated_rhs, *direction);
      } else {
        return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
      }
    }

    if (direction.has_value()) {
      return InvalidArgument(
          "A comparison direction is provided for a non-compare opcode: %s.",
          HloOpcodeString(binop));
    }
    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
  });
}

XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                      XlaOp lhs, XlaOp rhs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), binop, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction) {
  TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
  return Compare(
      shape, lhs, rhs, direction,
      Comparison::DefaultComparisonType(operand_shape.element_type()));
}

StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction,
                                    Comparison::Type type) {
  HloInstructionProto instr;
  instr.set_comparison_direction(ComparisonDirectionToString(direction));
  instr.set_comparison_type(ComparisonTypeToString(type));
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
}

XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    XlaOp updated_ehs = ehs;
    // The client API supports implicit broadcast for kSelect and kClamp, but
    // XLA does not support implicit broadcast. Make implicit broadcast
    // explicit and update the operands.
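    // Illustrative example (not in the original source): for a Select with a
    // scalar predicate and array operands of shape [2,3], the predicate is
    // broadcast to [2,3] before the ternary op is added.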
    if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
      TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
      TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
      TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));

      absl::optional<Shape> non_scalar_shape;
      for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
        if (shape->IsArray() && shape->rank() != 0) {
          if (non_scalar_shape.has_value()) {
            // TODO(jpienaar): The case where we need to compute the broadcasted
            // shape by considering multiple of the shapes is not implemented.
            // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
            TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
                         shape->dimensions())
                << "Unimplemented implicit broadcast.";
          } else {
            non_scalar_shape = *shape;
          }
        }
      }
      if (non_scalar_shape.has_value()) {
        if (ShapeUtil::IsScalar(*lhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_lhs,
                              AddBroadcastSequence(*non_scalar_shape, lhs));
        }
        if (ShapeUtil::IsScalar(*rhs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_rhs,
                              AddBroadcastSequence(*non_scalar_shape, rhs));
        }
        if (ShapeUtil::IsScalar(*ehs_shape)) {
          TF_ASSIGN_OR_RETURN(updated_ehs,
                              AddBroadcastSequence(*non_scalar_shape, ehs));
        }
      }
    }

    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
    TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
    StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
        triop, *lhs_shape, *rhs_shape, *ehs_shape);
    if (!status_or_shape.status().ok()) {
      return InvalidArgument(
          "%s Input scalar shapes may have been changed to non-scalar shapes.",
          status_or_shape.status().error_message());
    }

    return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
                          {updated_lhs, updated_rhs, updated_ehs});
  });
}

XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (literal.shape().IsArray() && literal.element_count() > 1 &&
        literal.IsAllFirst()) {
      Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
      HloInstructionProto instr;
      *instr.mutable_shape() = scalar.shape().ToProto();
      *instr.mutable_literal() = scalar.ToProto();
      TF_ASSIGN_OR_RETURN(
          XlaOp scalar_op,
          AddInstruction(std::move(instr), HloOpcode::kConstant));
      return Broadcast(scalar_op, literal.shape().dimensions());
    } else {
      HloInstructionProto instr;
      *instr.mutable_shape() = literal.shape().ToProto();
      *instr.mutable_literal() = literal.ToProto();
      return AddInstruction(std::move(instr), HloOpcode::kConstant);
    }
  });
}
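
// For illustration (not in the original source): a literal such as
// LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f}) has all elements equal to
// the first, so it is emitted as a scalar constant followed by a Broadcast
// to {3} instead of a full array literal.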

XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(iota_dimension);
    return AddInstruction(std::move(instr), HloOpcode::kIota);
  });
}

XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
  return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}
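
// For illustration (not in the original source): Iota(S32, 4) produces a
// rank-1 value of shape s32[4] holding {0, 1, 2, 3}; the Shape overload
// generalizes this to any dimension of a higher-rank shape.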

XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       absl::Span<const XlaOp> operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
                                         operand_shape_ptrs,
                                         /*to_apply=*/called_program_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
  });
}

XlaOp XlaBuilder::Parameter(
    int64 parameter_number, const Shape& shape, const string& name,
    const std::vector<bool>& replicated_at_leaf_buffers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!parameter_numbers_.insert(parameter_number).second) {
      return InvalidArgument("parameter %d already registered",
                             parameter_number);
    }
    instr.set_parameter_number(parameter_number);
    instr.set_name(name);
    *instr.mutable_shape() = shape.ToProto();
    if (!replicated_at_leaf_buffers.empty()) {
      auto replication = instr.mutable_parameter_replication();
      for (bool replicated : replicated_at_leaf_buffers) {
        replication->add_replicated_at_leaf_buffers(replicated);
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kParameter);
  });
}

XlaOp XlaBuilder::Broadcast(XlaOp operand,
                            absl::Span<const int64> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));

    // The client-level broadcast op just appends dimensions on the left (adds
    // lowest numbered dimensions). The HLO broadcast instruction is more
    // flexible and can add new dimensions anywhere. The instruction's
    // dimensions field maps operand dimensions to dimensions in the broadcast
    // output, so to append dimensions on the left the instruction's dimensions
    // should just be the n highest dimension numbers of the output shape where
    // n is the number of input dimensions.
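    // Illustrative example (not in the original source): broadcasting an
    // operand of shape [2,3] with broadcast_sizes {4} yields output shape
    // [4,2,3], so the dimensions field is {1,2}.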
    const int64 operand_rank = operand_shape->rank();
    std::vector<int64> dimensions(operand_rank);
    for (int i = 0; i < operand_rank; ++i) {
      dimensions[i] = i + shape.rank() - operand_rank;
    }
    return InDimBroadcast(shape, operand, dimensions);
  });
}

XlaOp XlaBuilder::BroadcastInDim(
    XlaOp operand, const absl::Span<const int64> out_dim_size,
    const absl::Span<const int64> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Output shape: in the case of a degenerate broadcast, out_dim_size is
    // not necessarily the same as the dimension sizes of the output shape.
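    // Illustrative example (not in the original source): an operand of shape
    // [3] with out_dim_size {2,3} and broadcast_dimensions {1} maps the
    // operand's only dimension onto output dimension 1, yielding shape [2,3].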
    TF_ASSIGN_OR_RETURN(auto output_shape,
                        ShapeUtil::MakeValidatedShape(
                            operand_shape->element_type(), out_dim_size));
    tensorflow::int64 broadcast_rank = broadcast_dimensions.size();
    if (operand_shape->rank() != broadcast_rank) {
      return InvalidArgument(
          "Size of broadcast_dimensions has to match operand's rank; operand "
          "rank: %lld, size of broadcast_dimensions %u.",
          operand_shape->rank(), broadcast_dimensions.size());
    }
    for (int i = 0; i < broadcast_rank; i++) {
      const tensorflow::int64 num_dims = out_dim_size.size();
      if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > num_dims) {
        return InvalidArgument("Broadcast dimension %lld is out of bounds",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(
          broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           *operand_shape, output_shape, broadcast_dimensions)
                           .status());
    std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_rank; i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If the broadcast is not degenerate, return the broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle the degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}

StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                            int64 inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  if (inferred_dimension != -1) {
    instr.add_dimensions(inferred_dimension);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
}

XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
                        absl::Span<const int64> limit_indices,
                        absl::Span<const int64> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
                                         *operand_shape, start_indices,
                                         limit_indices, strides));
    return SliceInternal(shape, operand, start_indices, limit_indices, strides);
  });
}

StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
                                          absl::Span<const int64> start_indices,
                                          absl::Span<const int64> limit_indices,
                                          absl::Span<const int64> strides) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int i = 0, end = start_indices.size(); i < end; i++) {
    auto* slice_config = instr.add_slice_dimensions();
    slice_config->set_start(start_indices[i]);
    slice_config->set_limit(limit_indices[i]);
    slice_config->set_stride(strides[i]);
  }
  return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}
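
// For illustration (not in the original source): Slice(x, {1}, {5}, {2}) on a
// rank-1 operand selects indices 1 and 3 (start inclusive, limit exclusive,
// stride 2), producing a result of shape [2].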

XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
                             int64 limit_index, int64 stride, int64 dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> starts(shape->rank(), 0);
    std::vector<int64> limits(shape->dimensions().begin(),
                              shape->dimensions().end());
    std::vector<int64> strides(shape->rank(), 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}

XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            *operand_shape, start_indices_shapes, slice_sizes));
    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64> slice_sizes) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  for (int64 size : slice_sizes) {
    instr.add_dynamic_slice_sizes(size);
  }

  std::vector<XlaOp> operands = {operand};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
}
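
// For illustration (not in the original source): DynamicSlice(x, {i, j},
// {2, 2}) on a rank-2 operand extracts a 2x2 window whose origin is given by
// the scalar start-index ops i and j, clamped so the window stays in bounds.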

XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                         *operand_shape, *update_shape, start_indices_shapes));
    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
  });
}

StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  std::vector<XlaOp> operands = {operand, update};
  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
  return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                        operands);
}

XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                              int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
                                         operand_shape_ptrs, dimension));
    return ConcatInDimInternal(shape, operands, dimension);
  });
}

StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  instr.add_dimensions(dimension);

  return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}
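
// For illustration (not in the original source): concatenating operands of
// shapes [2,3] and [2,5] along dimension 1 yields shape [2,8]; operands must
// agree on all non-concatenated dimensions.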

XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
                      const PaddingConfig& padding_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
                        GetShapePtr(padding_value));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferPadShape(
                         *operand_shape, *padding_value_shape, padding_config));
    return PadInternal(shape, operand, padding_value, padding_config);
  });
}

XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno,
                           int64 pad_lo, int64 pad_hi) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    PaddingConfig padding_config = MakeNoPaddingConfig(shape->rank());
    auto* dims = padding_config.mutable_dimensions(dimno);
    dims->set_edge_padding_low(pad_lo);
    dims->set_edge_padding_high(pad_hi);
    return Pad(operand, padding_value, padding_config);
  });
}
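
// For illustration (not in the original source): PadInDim(x, zero,
// /*dimno=*/0, /*pad_lo=*/1, /*pad_hi=*/2) on an operand of shape [4] yields
// shape [7], with one copy of the padding value before the data and two
// after.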

StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
                                        XlaOp padding_value,
                                        const PaddingConfig& padding_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_padding_config() = padding_config;
  return AddInstruction(std::move(instr), HloOpcode::kPad,
                        {operand, padding_value});
}

XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                          absl::Span<const int64> new_sizes,
                          int64 inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
                                               *operand_shape, dimensions,
                                               new_sizes, inferred_dimension));
    XlaOp transposed = IsIdentityPermutation(dimensions)
                           ? operand
                           : Transpose(operand, dimensions);
    return ReshapeInternal(shape, transposed, inferred_dimension);
  });
}
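
// For illustration (not in the original source): Reshape(x, {1,0}, {6}) on an
// operand of shape [2,3] first transposes to [3,2], since {1,0} is not the
// identity permutation, and then reshapes the result to [6].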

XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                          int64 inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    std::vector<int64> dimensions(shape->dimensions_size());
    std::iota(dimensions.begin(), dimensions.end(), 0);
    return Reshape(operand, dimensions, new_sizes, inferred_dimension);
  });
}

XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
                          int64 inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    return ReshapeInternal(shape, operand, inferred_dimension);
  });
}

XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
                                 absl::Span<const XlaOp> dim_sizes,
                                 absl::Span<const int64> new_size_bounds,
                                 const std::vector<bool>& dims_are_dynamic) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> dim_size_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
                        GetOperandShapes(dim_sizes));

    absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferDynamicReshapeShape(
                            *operand_shape, dim_size_shape_ptrs,
                            new_size_bounds, dims_are_dynamic));
    TF_RETURN_IF_ERROR(first_error_);
    std::vector<XlaOp> operands;
    operands.reserve(1 + dim_sizes.size());
    operands.push_back(operand);
    for (const XlaOp& dim_size : dim_sizes) {
      operands.push_back(dim_size);
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
                          operands);
  });
}
1176 
XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Nothing to collapse; return the operand rather than enqueueing a
      // trivial reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Check that the collapsed dimensions are in order and consecutive.
    for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
    TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));

    VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
    VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");

    std::vector<int64> new_sizes;
    for (int i = 0; i < original_shape->rank(); ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape->dimensions(i));
      } else {
        new_sizes.back() *= original_shape->dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";

    return Reshape(operand, new_sizes);
  });
}

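// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b`): collapsing the consecutive dimensions {1, 2} of a
// f32[2, 3, 4] operand multiplies their sizes together.
//
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "x");
//   XlaOp y = Collapse(x, /*dimensions=*/{1, 2});  // f32[2, 12]
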
void XlaBuilder::Trace(const string& tag, XlaOp operand) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
  });
}

XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
    TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
    TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
    HloOpcode opcode =
        true_shape->IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
    return TernaryOp(opcode, pred, on_true, on_false);
  });
}

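// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b`): element-wise selection between two arrays of
// matching shape, driven by a PRED array.
//
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {4}), "p");
//   XlaOp t = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4}), "t");
//   XlaOp f = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {4}), "f");
//   XlaOp sel = Select(p, t, f);  // f32[4]
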
XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferVariadicOpShape(
                            HloOpcode::kTuple, operand_shape_ptrs));
    return TupleInternal(shape, elements);
  });
}

StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
                                          absl::Span<const XlaOp> elements) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
}

XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
    if (!tuple_shape->IsTuple()) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(*tuple_shape));
    }
    if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
      return InvalidArgument(
          "GetTupleElement() index (%d) out of range for tuple shape %s", index,
          ShapeUtil::HumanString(*tuple_shape));
    }
    return GetTupleElementInternal(
        ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
        index);
  });
}

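// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b` and ops `x`, `y` built on it): packing two values
// into a tuple and projecting one back out.
//
//   XlaOp t = Tuple(&b, {x, y});
//   XlaOp first = GetTupleElement(t, 0);  // same shape as `x`
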
StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
                                                    XlaOp tuple_data,
                                                    int64 index) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_tuple_index(index);
  return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
                        {tuple_data});
}

XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
                      const PrecisionConfig* precision_config,
                      absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));

    DotDimensionNumbers dimension_numbers;
    dimension_numbers.add_lhs_contracting_dimensions(
        lhs_shape->dimensions_size() == 1 ? 0 : 1);
    dimension_numbers.add_rhs_contracting_dimensions(0);
    // Forward preferred_element_type; previously it was accepted here but
    // silently dropped before reaching DotGeneral.
    return DotGeneral(lhs, rhs, dimension_numbers, precision_config,
                      preferred_element_type);
  });
}

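// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b`): Dot contracts the last dimension of the lhs with
// the first dimension of the rhs, i.e. a plain matrix multiply for rank-2
// operands.
//
//   XlaOp lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "lhs");
//   XlaOp rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3, 4}), "rhs");
//   XlaOp prod = Dot(lhs, rhs);  // f32[2, 4]
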
XlaOp XlaBuilder::DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferDotOpShape(
            *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
    return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
                              precision_config);
  });
}

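// Illustrative usage (editorial sketch, not from the original file; assumes
// ops `lhs` of shape f32[8, 2, 3] and `rhs` of shape f32[8, 3, 4]): a
// batched matrix multiply expressed via explicit dimension numbers.
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   XlaOp out = DotGeneral(lhs, rhs, dnums);  // f32[8, 2, 4]
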
StatusOr<XlaOp> XlaBuilder::DotGeneralInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_dot_dimension_numbers() = dimension_numbers;
  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
}

Status XlaBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) const {
  if (lhs_shape.rank() != rhs_shape.rank()) {
    return InvalidArgument(
        "Convolution arguments must have same number of "
        "dimensions. Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_dims = lhs_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          return InvalidArgument("Expected %d elements for %s, but got %d.",
                                 num_spatial_dims, field_name, numbers.size());
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
                                   field_name, i, numbers.Get(i));
          }
        }
        return Status::OK();
      };
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("input_spatial_dimensions",
                               dimension_numbers.input_spatial_dimensions()));
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("kernel_spatial_dimensions",
                               dimension_numbers.kernel_spatial_dimensions()));
  return check_spatial_dimensions(
      "output_spatial_dimensions",
      dimension_numbers.output_spatial_dimensions());
}

XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> window_strides, Padding padding,
                       int64 feature_group_count, int64 batch_group_count,
                       const PrecisionConfig* precision_config,
                       absl::optional<PrimitiveType> preferred_element_type) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()),
      feature_group_count, batch_group_count, precision_config,
      preferred_element_type);
}

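// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b`): a 2D convolution with the default dimension
// numbers (input: batch, feature, spatial...; kernel: output feature, input
// feature, spatial...).
//
//   // Input: batch=1, features=3, 32x32 spatial. Kernel: 8 output features,
//   // 3 input features, 5x5 window.
//   XlaOp in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 3, 32, 32}),
//                        "in");
//   XlaOp k = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 3, 5, 5}), "k");
//   XlaOp out = Conv(in, k, /*window_strides=*/{1, 1}, Padding::kSame);
//   // out: f32[1, 8, 32, 32]
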
XlaOp XlaBuilder::ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()),
                     feature_group_count, batch_group_count, precision_config,
                     preferred_element_type);
}

XlaOp XlaBuilder::ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));

    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> base_area_dimensions(
        dimension_numbers.input_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
         ++i) {
      base_area_dimensions[i] =
          lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
    }

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    return ConvGeneral(lhs, rhs, window_strides,
                       MakePadding(base_area_dimensions, window_dimensions,
                                   window_strides, padding),
                       dimension_numbers, feature_group_count,
                       batch_group_count, precision_config,
                       preferred_element_type);
  });
}

XlaOp XlaBuilder::ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers, feature_group_count,
                            batch_group_count, precision_config,
                            preferred_element_type);
}

XlaOp XlaBuilder::ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    TF_ASSIGN_OR_RETURN(Window window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding,
                            lhs_dilation, rhs_dilation));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferConvolveShape(
            *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
            window, dimension_numbers, preferred_element_type));
    return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
                                      padding, lhs_dilation, rhs_dilation,
                                      dimension_numbers, feature_group_count,
                                      batch_group_count, precision_config);
  });
}

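// Illustrative usage (editorial sketch, not from the original file; assumes
// ops `in` of shape f32[1, 3, 32, 32] and `k` of shape f32[8, 3, 3, 3] with
// default dimension numbers): rhs_dilation of 2 implements an atrous
// ("dilated") convolution, while lhs_dilation > 1 corresponds to a
// transposed / fractionally-strided convolution.
//
//   XlaOp out = ConvGeneralDilated(
//       in, k, /*window_strides=*/{1, 1},
//       /*padding=*/{{2, 2}, {2, 2}},
//       /*lhs_dilation=*/{1, 1}, /*rhs_dilation=*/{2, 2},
//       CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/2));
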
StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
  TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
                                         window_dimensions, window_strides,
                                         padding, lhs_dilation, rhs_dilation));
  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferConvolveShape(
          *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);
  instr.set_padding_type(padding_type);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return std::move(instr);
}

XlaOp XlaBuilder::DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionInputGrad");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {input_sizes, lhs, rhs});
  });
}

XlaOp XlaBuilder::DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(activations, gradients, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers,
                               feature_group_count, batch_group_count,
                               precision_config, padding_type,
                               preferred_element_type));

    instr.set_custom_call_target("DynamicConvolutionKernelGrad");
    // The gradient of the kernel has the kernel's shape and shouldn't have
    // any dynamic sizes.
    instr.mutable_shape()->clear_is_dynamic_dimension();
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                          {activations, gradients});
  });
}

XlaOp XlaBuilder::DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        DynamicConvInstruction(
            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count,
            precision_config, padding_type, preferred_element_type));
    instr.set_custom_call_target("DynamicConvolutionForward");

    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
  });
}

StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }

  return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
}

XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                      const absl::Span<const int64> fft_length) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
                                         *operand_shape, fft_type, fft_length));
    return FftInternal(shape, operand, fft_type, fft_length);
  });
}

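// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b`): a batched complex-to-complex FFT over the
// trailing dimension.
//
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(C64, {8, 256}), "x");
//   XlaOp spectrum = Fft(x, FftType::FFT, /*fft_length=*/{256});
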
StatusOr<XlaOp> XlaBuilder::FftInternal(
    const Shape& shape, XlaOp operand, const FftType fft_type,
    const absl::Span<const int64> fft_length) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_fft_type(fft_type);
  for (int64 i : fft_length) {
    instr.add_fft_length(i);
  }

  return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
}

StatusOr<XlaOp> XlaBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  HloInstructionProto instr;
  *instr.mutable_triangular_solve_options() = std::move(options);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b});
}

StatusOr<XlaOp> XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
                                             bool lower) {
  HloInstructionProto instr;
  xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
  options.set_lower(lower);
  *instr.mutable_shape() = shape.ToProto();

  return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
}

XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
    instr.set_infeed_config(config);

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    // Infeed takes a single token operand. Generate the token to pass to the
    // infeed.
    XlaOp token;
    auto make_token = [&]() {
      HloInstructionProto token_instr;
      *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
      return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
    };
    if (sharding()) {
      // Arbitrarily assign token to device 0.
      OpSharding sharding = sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this, sharding);
      TF_ASSIGN_OR_RETURN(token, make_token());
    } else {
      TF_ASSIGN_OR_RETURN(token, make_token());
    }

    // The sharding is set by the client according to the data tuple shape.
    // However, the shape of the infeed instruction is a tuple containing the
    // data and a token. For tuple sharding type, the sharding must be changed
    // to accommodate the token.
    XlaOp infeed;
    if (sharding() && sharding()->type() == OpSharding::TUPLE) {
      // TODO(b/80000000): Remove this when clients have been updated to handle
      // tokens.
      OpSharding infeed_instruction_sharding = *sharding();
      // Arbitrarily assign the token to device 0.
      *infeed_instruction_sharding.add_tuple_shardings() =
          sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this,
                                                  infeed_instruction_sharding);
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    } else {
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    }

    // The infeed instruction produces a tuple of the infeed data and a token.
    // Return the XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto infeed_data;
    *infeed_data.mutable_shape() = shape.ToProto();
    infeed_data.set_tuple_index(0);
    return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
                          {infeed});
  });
}

XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
                                  const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }
    return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
  });
}

StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = infeed_instruction_shape.ToProto();
  instr.set_infeed_config(config);
  return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
}

void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
                         const string& outfeed_config) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();

    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();

    instr.set_outfeed_config(outfeed_config);

    // Outfeed takes a token as its second operand. Generate the token to pass
    // to the outfeed.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    TF_RETURN_IF_ERROR(
        AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
            .status());

    // The outfeed instruction produces a token. However, existing users expect
    // a nil shape (empty tuple). This should only be relevant if the outfeed is
    // the root of a computation.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto tuple_instr;
    *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();

    // The dummy tuple should have no sharding.
    {
      XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
      TF_ASSIGN_OR_RETURN(
          XlaOp empty_tuple,
          AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
      return empty_tuple;
    }
  });
}

XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
                                   const Shape& shape_with_layout,
                                   const string& outfeed_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }
    return OutfeedWithTokenInternal(operand, token, shape_with_layout,
                                    outfeed_config);
  });
}

StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
    const string& outfeed_config) {
  HloInstructionProto instr;
  *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
  *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
  instr.set_outfeed_config(outfeed_config);
  return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
                        {operand, token});
}

XlaOp XlaBuilder::CreateToken() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
  });
}

XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (tokens.empty()) {
      return InvalidArgument("AfterAll requires at least one operand");
    }
    for (int i = 0, end = tokens.size(); i < end; ++i) {
      XlaOp operand = tokens[i];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      if (!operand_shape->IsToken()) {
        return InvalidArgument(
            "All operands to AfterAll must be tokens; operand %d has shape %s",
            i, ShapeUtil::HumanString(*operand_shape));
      }
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
  });
}

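// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b`): tokens order side-effecting ops such as infeeds
// and outfeeds. Here two independent token chains are joined into one.
//
//   XlaOp t0 = CreateToken(&b);
//   XlaOp t1 = CreateToken(&b);
//   XlaOp joined = AfterAll(&b, {t0, t1});  // a single token op
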
XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      int64 operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        ++operand_num;
      }
    }
    return CustomCallInternal(call_target_name, operands, shape, opaque,
                              operand_shapes_with_layout, has_side_effect,
                              output_operand_aliasing, literal);
  });
}

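// Illustrative usage (editorial sketch, not from the original file;
// "my_host_fn" is a hypothetical call target that the client would have to
// register with the backend, and `x` is an op on a live XlaBuilder `b`):
//
//   Shape out_shape = ShapeUtil::MakeShape(F32, {128});
//   XlaOp y = CustomCall(&b, "my_host_fn", /*operands=*/{x}, out_shape);
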
StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_custom_call_target(call_target_name);
  instr.set_backend_config(opaque);
  if (operand_shapes_with_layout.has_value()) {
    instr.set_constrain_layout(true);
    for (const Shape& operand_shape : *operand_shapes_with_layout) {
      *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
    }
  }
  if (literal != nullptr) {
    *instr.mutable_literal() = literal->ToProto();
  }
  instr.set_custom_call_has_side_effect(has_side_effect);
  for (const auto& pair : output_operand_aliasing) {
    auto aliasing = instr.add_custom_call_output_operand_aliasing();
    aliasing->set_operand_index(pair.second.first);
    for (int64 index : pair.second.second) {
      aliasing->add_operand_shape_index(index);
    }
    for (int64 index : pair.first) {
      aliasing->add_output_shape_index(index);
    }
  }
  return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const XlaComputation& computation, const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.set_custom_call_target(call_target_name);
    instr.set_backend_config(opaque);
    if (literal != nullptr) {
      *instr.mutable_literal() = literal->ToProto();
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      instr.set_constrain_layout(true);
      int64 operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
        ++operand_num;
      }
    }
    AddCalledComputation(computation, &instr);
    for (const auto& pair : output_operand_aliasing) {
      auto aliasing = instr.add_custom_call_output_operand_aliasing();
      aliasing->set_operand_index(pair.second.first);
      for (int64 index : pair.second.second) {
        aliasing->add_operand_shape_index(index);
      }
      for (int64 index : pair.first) {
        aliasing->add_output_shape_index(index);
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
  });
}

XlaOp XlaBuilder::Transpose(XlaOp operand,
                            absl::Span<const int64> permutation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
                                         *operand_shape, permutation));
    return TransposeInternal(shape, operand, permutation);
  });
}

StatusOr<XlaOp> XlaBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64 dim : permutation) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
}

XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
                                         *operand_shape, dimensions));
    return RevInternal(shape, operand, dimensions);
  });
}

StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> dimensions) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64 dim : dimensions) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
}

XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       const XlaComputation& comparator, int64 dimension,
                       bool is_stable) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                        GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
                                         HloOpcode::kSort, operand_shape_ptrs));
    return SortInternal(shape, operands, comparator, dimension, is_stable);
  });
}

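// Illustrative usage (editorial sketch, not from the original file; assumes
// a live XlaBuilder `b`, an op `keys` of shape f32[1024], and the comparator
// helpers from xla/client/lib/comparators.h): sorts along the last dimension
// by default (dimension == -1).
//
//   XlaComputation less = CreateScalarLtComputation({F32}, &b);
//   XlaOp sorted = Sort({keys}, less, /*dimension=*/-1, /*is_stable=*/true);
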
StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
                                         absl::Span<const XlaOp> operands,
                                         const XlaComputation& comparator,
                                         int64 dimension, bool is_stable) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_is_stable(is_stable);
  if (dimension == -1) {
    TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
    dimension = keys_shape->rank() - 1;
  }
  instr.add_dimensions(dimension);
  AddCalledComputation(comparator, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
}

XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
  });
}

XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         *operand_shape, new_element_type));
    return BitcastConvertTypeInternal(shape, operand);
  });
}

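// Illustrative contrast between the two conversions (editorial sketch, not
// from the original file; assumes an op `x` of shape f32[4]):
//
//   XlaOp a = ConvertElementType(x, S32);  // value cast: 1.0f -> 1
//   XlaOp c = BitcastConvertType(x, S32);  // bit reinterpretation:
//                                          // 1.0f -> 0x3F800000
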
StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
                                                       XlaOp operand) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
                        {operand});
}

XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
  return TernaryOp(HloOpcode::kClamp, min, operand, max);
}

XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions,
                      absl::Span<const XlaOp> static_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!static_operands.empty()) {
      return Unimplemented("static_operands is not supported in Map");
    }

    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferMapShape(
                         operand_shape_ptrs, called_program_shape, dimensions));
    *instr.mutable_shape() = shape.ToProto();

    Shape output_shape(instr.shape());
    const int64 output_rank = output_shape.rank();
    AddCalledComputation(computation, &instr);
    std::vector<XlaOp> new_operands(operands.begin(), operands.end());
    for (XlaOp& new_operand : new_operands) {
      TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
      const int64 rank = shape->rank();
      if (rank != output_rank) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            InDimBroadcast(output_shape, new_operand, {}));
        TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
      }
      if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            AddBroadcastSequence(output_shape, new_operand));
      }
    }

    return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
  });
}

RngOp(RandomDistribution distribution,absl::Span<const XlaOp> parameters,const Shape & shape)2165 XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
2166                         absl::Span<const XlaOp> parameters,
2167                         const Shape& shape) {
2168   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2169     // Check the number of parameters per RNG distribution.
2170     switch (distribution) {
2171       case RandomDistribution::RNG_NORMAL:
2172       case RandomDistribution::RNG_UNIFORM:
2173         if (parameters.size() != 2) {
2174           return InvalidArgument(
2175               "RNG distribution (%s) expects 2 parameters, but got %ld",
2176               RandomDistribution_Name(distribution), parameters.size());
2177         }
2178         break;
2179       default:
2180         LOG(FATAL) << "unhandled distribution " << distribution;
2181     }
2182 
2183     TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
2184     return RngOpInternal(distribution, parameters, shape);
2185   });
2186 }
2187 
StatusOr<XlaOp> XlaBuilder::RngOpInternal(RandomDistribution distribution,
                                          absl::Span<const XlaOp> parameters,
                                          const Shape& shape) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_distribution(distribution);

  return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
}

XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}

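// Editor's sketch (not part of the original source): RngUniform draws
// elementwise from [a, b). Builder and names are illustrative.
//
//   XlaBuilder b("rng_example");
//   XlaOp lo = ConstantR0<float>(&b, 0.0f);
//   XlaOp hi = ConstantR0<float>(&b, 1.0f);
//   XlaOp u = RngUniform(lo, hi, ShapeUtil::MakeShape(F32, {2, 3}));
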
XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
                                  XlaOp initial_state, const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
    Shape output_shape = shape;
    switch (output_shape.element_type()) {
      case PrimitiveType::F32:
      case PrimitiveType::S32:
      case PrimitiveType::U32:
        output_shape.set_element_type(PrimitiveType::U32);
        break;
      case PrimitiveType::F64:
      case PrimitiveType::S64:
      case PrimitiveType::U64:
        output_shape.set_element_type(PrimitiveType::U64);
        break;
      default:
        return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
                               PrimitiveType_Name(output_shape.element_type()));
    }
    return RngBitGeneratorInternal(
        ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm,
        initial_state);
  });
}

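// Editor's sketch (not part of the original source): the result is a tuple
// of {new RNG state, generated bits}. The expected state shape is algorithm
// dependent (the u64[3] state below is an assumption for Philox); names are
// illustrative.
//
//   XlaBuilder b("rng_bits");
//   XlaOp state = ConstantR1<uint64>(&b, {7, 42, 0});
//   XlaOp out = RngBitGenerator(RandomAlgorithm::RNG_PHILOX, state,
//                               ShapeUtil::MakeShape(U32, {128}));
//   XlaOp new_state = GetTupleElement(out, 0);
//   XlaOp bits = GetTupleElement(out, 1);
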
StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  HloInstructionProto instr;
  *instr.mutable_shape() = full_result_shape.ToProto();
  instr.set_rng_algorithm(algorithm);
  return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
                        {initial_state});
}

XlaOp XlaBuilder::While(const XlaComputation& condition,
                        const XlaComputation& body, XlaOp init) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Infer shape.
    TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
                        condition.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
                                         condition_program_shape,
                                         body_program_shape, *init_shape));
    return WhileInternal(shape, condition, body, init);
  });
}

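// Editor's sketch (not part of the original source): a counting loop built
// with While. Condition and body both take and return the loop-carried
// value; names are illustrative.
//
//   XlaBuilder b("while_example");
//   const Shape s32 = ShapeUtil::MakeShape(S32, {});
//   auto cond_b = b.CreateSubBuilder("cond");
//   Lt(Parameter(cond_b.get(), 0, s32, "i"),
//      ConstantR0<int32>(cond_b.get(), 10));
//   XlaComputation cond = cond_b->Build().ValueOrDie();
//   auto body_b = b.CreateSubBuilder("body");
//   Add(Parameter(body_b.get(), 0, s32, "i"),
//       ConstantR0<int32>(body_b.get(), 1));
//   XlaComputation body = body_b->Build().ValueOrDie();
//   XlaOp result = While(cond, body, ConstantR0<int32>(&b, 0));
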
StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
                                          const XlaComputation& condition,
                                          const XlaComputation& body,
                                          XlaOp init) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  // Body comes before condition computation in the vector.
  AddCalledComputation(body, &instr);
  AddCalledComputation(condition, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
}

XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
                         const GatherDimensionNumbers& dimension_numbers,
                         absl::Span<const int64> slice_sizes,
                         bool indices_are_sorted) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
                        GetShapePtr(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
                                         *input_shape, *start_indices_shape,
                                         dimension_numbers, slice_sizes));
    return GatherInternal(shape, input, start_indices, dimension_numbers,
                          slice_sizes, indices_are_sorted);
  });
}

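// Editor's sketch (not part of the original source): gathering whole rows of
// a [5,4] matrix selected by a vector of row indices (a take-along-rows).
// The dimension numbers below are illustrative, not prescriptive.
//
//   XlaBuilder b("gather_example");
//   XlaOp input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 4}), "in");
//   XlaOp idx = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {2}), "idx");
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 comes from the slice
//   dnums.add_collapsed_slice_dims(0);  // the gathered row dim collapses
//   dnums.add_start_index_map(0);       // indices address input dim 0
//   dnums.set_index_vector_dim(1);      // each index is a scalar
//   XlaOp rows = Gather(input, idx, dnums, /*slice_sizes=*/{1, 4});
//   // rows : f32[2, 4]
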
StatusOr<XlaOp> XlaBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  HloInstructionProto instr;
  instr.set_indices_are_sorted(indices_are_sorted);
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_gather_dimension_numbers() = dimension_numbers;
  for (int64 bound : slice_sizes) {
    instr.add_gather_slice_sizes(bound);
  }

  return AddInstruction(std::move(instr), HloOpcode::kGather,
                        {input, start_indices});
}

XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers,
                          bool indices_are_sorted, bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
                        GetShapePtr(scatter_indices));
    TF_ASSIGN_OR_RETURN(const Shape* updates_shape, GetShapePtr(updates));
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        update_computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferScatterShape(
                         *input_shape, *scatter_indices_shape, *updates_shape,
                         to_apply_shape, dimension_numbers));
    return ScatterInternal(shape, input, scatter_indices, updates,
                           update_computation, dimension_numbers,
                           indices_are_sorted, unique_indices);
  });
}

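// Editor's sketch (not part of the original source): a scatter-add that adds
// each row of `upd` into the row of `in` named by `idx`. The combiner and
// dimension numbers are illustrative assumptions.
//
//   XlaBuilder b("scatter_example");
//   auto sum_b = b.CreateSubBuilder("add");
//   const Shape scalar = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(sum_b.get(), 0, scalar, "x"),
//       Parameter(sum_b.get(), 1, scalar, "y"));
//   XlaComputation add = sum_b->Build().ValueOrDie();
//   XlaOp in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 4}), "in");
//   XlaOp idx = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {2}), "idx");
//   XlaOp upd = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {2, 4}), "upd");
//   ScatterDimensionNumbers dnums;
//   dnums.add_update_window_dims(1);
//   dnums.add_inserted_window_dims(0);
//   dnums.add_scatter_dims_to_operand_dims(0);
//   dnums.set_index_vector_dim(1);
//   XlaOp out = Scatter(in, idx, upd, add, dnums);
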
StatusOr<XlaOp> XlaBuilder::ScatterInternal(
    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
    const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_indices_are_sorted(indices_are_sorted);
    instr.set_unique_indices(unique_indices);
    *instr.mutable_shape() = shape.ToProto();
    *instr.mutable_scatter_dimension_numbers() = dimension_numbers;

    AddCalledComputation(update_computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kScatter,
                          {input, scatter_indices, updates});
  });
}

XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
                              const XlaComputation& true_computation,
                              XlaOp false_operand,
                              const XlaComputation& false_computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(predicate));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
      return InvalidArgument(
          "Argument to predicated-Conditional is not a scalar of PRED type "
          "(%s).",
          ShapeUtil::HumanString(*shape));
    }
    // The index of true_computation must be 0 and that of false_computation
    // must be 1.
    return ConditionalImpl(predicate, {&true_computation, &false_computation},
                           {true_operand, false_operand});
  });
}

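// Editor's sketch (not part of the original source): selecting between two
// scalar computations based on a PRED scalar; names are illustrative.
//
//   XlaBuilder b("cond_example");
//   const Shape f32s = ShapeUtil::MakeShape(F32, {});
//   auto t_b = b.CreateSubBuilder("true_branch");
//   Neg(Parameter(t_b.get(), 0, f32s, "x"));
//   XlaComputation true_comp = t_b->Build().ValueOrDie();
//   auto f_b = b.CreateSubBuilder("false_branch");
//   Parameter(f_b.get(), 0, f32s, "x");  // identity branch
//   XlaComputation false_comp = f_b->Build().ValueOrDie();
//   XlaOp pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p");
//   XlaOp x = Parameter(&b, 1, f32s, "x");
//   XlaOp y = Conditional(pred, x, true_comp, x, false_comp);
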
XlaOp XlaBuilder::Conditional(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index));

    if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
      return InvalidArgument(
          "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
          ShapeUtil::HumanString(*shape));
    }
    return ConditionalImpl(branch_index, branch_computations, branch_operands);
  });
}

XlaOp XlaBuilder::ConditionalImpl(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
                        GetShapePtr(branch_index));
    std::vector<Shape> branch_operand_shapes(branch_operands.size());
    std::vector<ProgramShape> branch_computation_shapes(
        branch_computations.size());
    for (int j = 0, end = branch_operands.size(); j < end; ++j) {
      TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
                          GetShape(branch_operands[j]));
      TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
                          branch_computations[j]->GetProgramShape());
    }
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferConditionalShape(
                            *branch_index_shape, branch_computation_shapes,
                            branch_operand_shapes));
    *instr.mutable_shape() = shape.ToProto();

    for (const XlaComputation* branch_computation : branch_computations) {
      AddCalledComputation(*branch_computation, &instr);
    }

    std::vector<XlaOp> operands(1, branch_index);
    for (const XlaOp branch_operand : branch_operands) {
      operands.emplace_back(branch_operand);
    }
    return AddInstruction(std::move(instr), HloOpcode::kConditional,
                          absl::MakeSpan(operands));
  });
}

Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
  if (this != op.builder()) {
    return InvalidArgument(
        "XlaOp with handle %d was built by builder '%s', but is being used "
        "in builder '%s'",
        op.handle(), op.builder()->name(), name());
  }
  return Status::OK();
}

XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
                         const XlaComputation& computation,
                         absl::Span<const int64> dimensions_to_reduce) {
  return Reduce(absl::Span<const XlaOp>({operand}),
                absl::Span<const XlaOp>({init_value}), computation,
                dimensions_to_reduce);
}

XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
                         absl::Span<const XlaOp> init_values,
                         const XlaComputation& computation,
                         absl::Span<const int64> dimensions_to_reduce) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());

    std::vector<XlaOp> all_operands;
    all_operands.insert(all_operands.end(), operands.begin(), operands.end());
    all_operands.insert(all_operands.end(), init_values.begin(),
                        init_values.end());

    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
                        GetOperandShapes(all_operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });

    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferReduceShape(
            operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
    return ReduceInternal(shape, all_operands, computation,
                          dimensions_to_reduce);
  });
}

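// Editor's sketch (not part of the original source): summing a [2,3] array
// over dimension 1, yielding f32[2]; names are illustrative.
//
//   XlaBuilder b("reduce_example");
//   auto sum_b = b.CreateSubBuilder("add");
//   const Shape scalar = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(sum_b.get(), 0, scalar, "x"),
//       Parameter(sum_b.get(), 1, scalar, "y"));
//   XlaComputation add = sum_b->Build().ValueOrDie();
//   XlaOp in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "in");
//   XlaOp sum = Reduce(in, ConstantR0<float>(&b, 0.0f), add,
//                      /*dimensions_to_reduce=*/{1});
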
StatusOr<XlaOp> XlaBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64> dimensions_to_reduce) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();

    for (int64 dim : dimensions_to_reduce) {
      instr.add_dimensions(dim);
    }

    AddCalledComputation(computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
  });
}

XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<int64> all_dimnos(operand_shape->rank());
    std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
    return Reduce(operand, init_value, computation, all_dimnos);
  });
}

XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReduceWindow(absl::MakeSpan(&operand, 1),
                      absl::MakeSpan(&init_value, 1), computation,
                      window_dimensions, window_strides, padding);
}

XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
                               absl::Span<const XlaOp> init_values,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const Shape* operand_shape = nullptr;
    for (const auto& operand : operands) {
      TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
      TF_RETURN_IF_ERROR(
          ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
                                window_dimensions, window_strides));
    }
    CHECK(operand_shape != nullptr);
    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);
    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions.
        // We create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          ReduceWindowInternal(operands, init_values, computation,
                               window_dimensions, window_strides, {}, {},
                               padding_values));
      instr.set_custom_call_target("DynamicReduceWindowSamePadding");
      std::vector<XlaOp> args;
      args.insert(args.end(), operands.begin(), operands.end());
      args.insert(args.end(), init_values.begin(), init_values.end());
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall, args);
    }
    return ReduceWindowWithGeneralPadding(
        operands, init_values, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}

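// Editor's sketch (not part of the original source): a 2x2 max-pool with
// stride 2 over a [4,4] array, expressed as a ReduceWindow; names are
// illustrative (the init value would need <limits> in real client code).
//
//   XlaBuilder b("reduce_window_example");
//   auto max_b = b.CreateSubBuilder("max");
//   const Shape scalar = ShapeUtil::MakeShape(F32, {});
//   Max(Parameter(max_b.get(), 0, scalar, "x"),
//       Parameter(max_b.get(), 1, scalar, "y"));
//   XlaComputation max = max_b->Build().ValueOrDie();
//   XlaOp in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 4}), "in");
//   XlaOp init =
//       ConstantR0<float>(&b, -std::numeric_limits<float>::infinity());
//   XlaOp pooled = ReduceWindow(in, init, max, /*window_dimensions=*/{2, 2},
//                               /*window_strides=*/{2, 2}, Padding::kValid);
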
XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (operands.size() == 1) {
      const auto& operand = operands[0];
      const auto& init_value = init_values[0];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      operand_shapes.push_back(operand_shape);
      TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
      init_shapes.push_back(init_shape);

      TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                          computation.GetProgramShape());
      TF_ASSIGN_OR_RETURN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              window_dimensions, window_strides, padding,
                              /*lhs_dilation=*/base_dilations,
                              /*rhs_dilation=*/window_dilations));
      TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
                                           absl::MakeSpan(operand_shapes),
                                           absl::MakeSpan(init_shapes), window,
                                           to_apply_shape));
      return ReduceWindowInternal(shape, operands[0], init_values[0],
                                  computation, window);
    }

    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        ReduceWindowInternal(operands, init_values, computation,
                             window_dimensions, window_strides, base_dilations,
                             window_dilations, padding));
    std::vector<XlaOp> args;
    args.insert(args.end(), operands.begin(), operands.end());
    args.insert(args.end(), init_values.begin(), init_values.end());
    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, args);
  });
}

StatusOr<HloInstructionProto> XlaBuilder::ReduceWindowInternal(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  for (int i = 0; i < operands.size(); ++i) {
    const auto& operand = operands[i];
    const auto& init_value = init_values[i];
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    operand_shapes.push_back(operand_shape);
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
    init_shapes.push_back(init_shape);
  }
  TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                      computation.GetProgramShape());
  TF_ASSIGN_OR_RETURN(auto window,
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/base_dilations,
                          /*rhs_dilation=*/window_dilations));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferReduceWindowShape(
                          absl::MakeSpan(operand_shapes),
                          absl::MakeSpan(init_shapes), window, to_apply_shape));
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);
  AddCalledComputation(computation, &instr);
  return instr;
}

StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);

  AddCalledComputation(computation, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
                        {operand, init_value});
}

XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                                    float epsilon, int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferBatchNormTrainingShape(
            *operand_shape, *scale_shape, *offset_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
                          {operand, scale, offset});
  });
}

XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                                     XlaOp mean, XlaOp variance, float epsilon,
                                     int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
    TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferBatchNormInferenceShape(
                            *operand_shape, *scale_shape, *offset_shape,
                            *mean_shape, *variance_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
                          {operand, scale, offset, mean, variance});
  });
}

XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                                XlaOp batch_var, XlaOp grad_output,
                                float epsilon, int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape, GetShapePtr(batch_mean));
    TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
    TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
                        GetShapePtr(grad_output));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBatchNormGradShape(
                         *operand_shape, *scale_shape, *batch_mean_shape,
                         *batch_var_shape, *grad_output_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
                          {operand, scale, batch_mean, batch_var, grad_output});
  });
}

XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
                            int64 shard_count,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Layout>& layout,
                            const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferAllGatherShape(
                            *operand_shape, all_gather_dimension, shard_count));
    if (layout) {
      *inferred_shape.mutable_layout() = *layout;
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = inferred_shape.ToProto();

    instr.add_dimensions(all_gather_dimension);
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      instr.set_use_global_device_ids(use_global_device_ids.value());
    }

    TF_ASSIGN_OR_RETURN(
        auto all_gather,
        AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand}));
    return all_gather;
  });
}

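// Editor's sketch (not part of the original source): with 4 participants,
// gathering f32[8,128] shards along dimension 0 yields f32[32,128] on every
// replica; values are illustrative.
//
//   XlaBuilder b("all_gather_example");
//   XlaOp shard = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 128}), "x");
//   XlaOp full = AllGather(shard, /*all_gather_dimension=*/0,
//                          /*shard_count=*/4);
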
XlaOp XlaBuilder::CrossReplicaSum(
    XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    const Shape* element_shape;
    if (shape->IsTuple()) {
      if (shape->tuple_shapes_size() == 0) {
        return Unimplemented(
            "0 element tuple CrossReplicaSum is not supported");
      }
      element_shape = &shape->tuple_shapes(0);
    } else {
      element_shape = shape;
    }
    const Shape scalar_shape =
        ShapeUtil::MakeShape(element_shape->element_type(), {});
    auto b = CreateSubBuilder("sum");
    auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
    auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
    if (scalar_shape.element_type() == PRED) {
      Or(x, y);
    } else {
      Add(x, y);
    }
    TF_ASSIGN_OR_RETURN(auto computation, b->Build());
    return AllReduce(operand, computation, replica_groups,
                     /*channel_id=*/absl::nullopt);
  });
}

XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Shape>& shape_with_layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple AllReduce is not supported");
      }
      for (int64 i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of AllReduce must have the same "
              "element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferAllReduceShape(operand_shapes));
    if (shape_with_layout) {
      if (!LayoutUtil::HasLayout(*shape_with_layout)) {
        return InvalidArgument("shape_with_layout must have the layout set: %s",
                               shape_with_layout->ToString());
      }
      if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
        return InvalidArgument(
            "Provided shape_with_layout must be compatible with the "
            "operand shape: %s vs %s",
            shape_with_layout->ToString(), operand_shape->ToString());
      }
      instr.set_constrain_layout(true);
      if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
        // For a single-element tuple, take the tuple element shape.
        TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
        *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
      } else {
        *instr.mutable_shape() = shape_with_layout->ToProto();
      }
    } else {
      *instr.mutable_shape() = inferred_shape.ToProto();
    }

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }

    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }

    AddCalledComputation(computation, &instr);

    TF_ASSIGN_OR_RETURN(
        auto all_reduce,
        AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
    if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
      // For a single-element tuple, wrap the result into a tuple.
      TF_RET_CHECK(operand_shapes.size() == 1);
      TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
      return Tuple({all_reduce});
    }
    return all_reduce;
  });
}

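// Editor's sketch (not part of the original source): summing a tensor across
// all replicas with a scalar add combiner; names are illustrative.
//
//   XlaBuilder b("all_reduce_example");
//   auto sum_b = b.CreateSubBuilder("add");
//   const Shape scalar = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(sum_b.get(), 0, scalar, "x"),
//       Parameter(sum_b.get(), 1, scalar, "y"));
//   XlaComputation add = sum_b->Build().ValueOrDie();
//   XlaOp grad = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1024}), "g");
//   XlaOp summed = AllReduce(grad, add);
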
XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                           int64 concat_dimension, int64 split_count,
                           const std::vector<ReplicaGroup>& replica_groups,
                           const absl::optional<Layout>& layout) {
  // An array AllToAll may need to violate the layout constraint to be legal,
  // so use the tuple version whenever a layout is specified.
  if (layout.has_value()) {
    return AllToAllTuple(operand, split_dimension, concat_dimension,
                         split_count, replica_groups, layout);
  }
  return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
                       replica_groups);
}

XlaOp XlaBuilder::AllToAllArray(
    XlaOp operand, int64 split_dimension, int64 concat_dimension,
    int64 split_count, const std::vector<ReplicaGroup>& replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape all_to_all_shape,
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count));
    HloInstructionProto instr;
    *instr.mutable_shape() = operand_shape->ToProto();
    if (replica_groups.empty()) {
      auto* group = instr.add_replica_groups();
      for (int64 i = 0; i < split_count; ++i) {
        group->add_replica_ids(i);
      }
    } else {
      for (const ReplicaGroup& group : replica_groups) {
        *instr.add_replica_groups() = group;
      }
    }
    instr.add_dimensions(split_dimension);
    TF_ASSIGN_OR_RETURN(
        XlaOp all_to_all,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
    if (split_dimension == concat_dimension) {
      return all_to_all;
    }
    DimensionVector sizes;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      if (i != split_dimension) {
        sizes.push_back(operand_shape->dimensions(i));
        continue;
      }
      sizes.push_back(split_count);
      sizes.push_back(operand_shape->dimensions(i) / split_count);
    }
    all_to_all = Reshape(all_to_all, sizes);

    std::vector<int64> permutation;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      int64 dim_after_reshape = i >= split_dimension ? i + 1 : i;
      if (i == concat_dimension) {
        permutation.push_back(split_dimension);
      }
      permutation.push_back(dim_after_reshape);
    }
    all_to_all = Transpose(all_to_all, permutation);
    return Reshape(all_to_all_shape, all_to_all);
  });
}

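// Editor's note (not part of the original source): when split_dimension !=
// concat_dimension, the code above decomposes the array AllToAll into
// AllToAll + Reshape + Transpose + Reshape. Worked example with 3
// participants, operand f32[6,5], split_dimension=0, concat_dimension=1:
//
//   AllToAll   : f32[6,5]    (shape unchanged by the HLO op itself)
//   Reshape    : f32[3,2,5]  (split dim 0 into split_count x 6/3)
//   Transpose  : f32[2,3,5]  (permutation {1,0,2} moves the split-count
//                             dim next to the concat dim)
//   Reshape    : f32[2,15]   (the inferred all-to-all result shape)
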
XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64 split_dimension,
                                int64 concat_dimension, int64 split_count,
                                const std::vector<ReplicaGroup>& replica_groups,
                                const absl::optional<Layout>& layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    // The HLO AllToAll instruction currently only handles the data
    // communication: it accepts N already-split parts and scatters them to N
    // cores, and each core gathers the N received parts into a tuple as the
    // output. So here we explicitly split the operand before the HLO
    // AllToAll and concatenate the tuple elements afterwards.
    //
    // First, run shape inference to make sure the shapes are valid.
    TF_RETURN_IF_ERROR(
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count)
            .status());

    // Split into N parts.
    std::vector<XlaOp> slices;
    slices.reserve(split_count);
    const int64 block_size =
        operand_shape->dimensions(split_dimension) / split_count;
    for (int i = 0; i < split_count; i++) {
      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
                                  /*limit_index=*/(i + 1) * block_size,
                                  /*stride=*/1, /*dimno=*/split_dimension));
    }

    // Handle data communication.
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
    std::vector<const Shape*> slice_shape_ptrs;
    absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));

    if (layout) {
      TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
      for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
        const int64 layout_minor_to_major_size =
            layout->minor_to_major().size();
        if (layout_minor_to_major_size != shape.tuple_shapes(i).rank()) {
          return InvalidArgument(
              "Provided layout must be compatible with the operand shape: %s "
              "vs %s",
              layout->ToString(), operand_shape->ToString());
        }
        *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
      }
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = shape.ToProto();

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    TF_ASSIGN_OR_RETURN(
        XlaOp alltoall,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));

    // Concat the N received parts.
    std::vector<XlaOp> received;
    received.reserve(split_count);
    for (int i = 0; i < split_count; i++) {
      received.push_back(this->GetTupleElement(alltoall, i));
    }
    return this->ConcatInDim(received, concat_dimension);
  });
}

XlaOp XlaBuilder::CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferCollectivePermuteShape(*operand_shape));
    *instr.mutable_shape() = shape.ToProto();

    for (const auto& pair : source_target_pairs) {
      auto* proto_pair = instr.add_source_target_pairs();
      proto_pair->set_source(pair.first);
      proto_pair->set_target(pair.second);
    }

    return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
                          {operand});
  });
}

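// Editor's sketch (not part of the original source): rotating data one step
// around a 4-replica ring; each pair is {source, target}.
//
//   XlaBuilder b("permute_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {16}), "x");
//   XlaOp y = CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
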
XlaOp XlaBuilder::ReplicaId() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
  });
}

XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                   absl::Span<const int64> window_dimensions,
                                   absl::Span<const int64> window_strides,
                                   Padding padding, XlaOp source,
                                   XlaOp init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);

    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding sizes
        // need to be rewritten by the dynamic padder using HloInstructions.
        // We create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          SelectAndScatterInternal(operand, select, window_dimensions,
                                   window_strides, padding_values, source,
                                   init_value, scatter));
      instr.set_custom_call_target("DynamicSelectAndScatterSamePadding");
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
                            {operand, source, init_value});
    }
    return SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides, padding_values,
        source, init_value, scatter);
  });
}

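// Editor's sketch (not part of the original source): the classic max-pool
// gradient pattern. `select` picks the window maximum (GE comparison) and
// `scatter` adds the incoming gradient at the selected position; names are
// illustrative.
//
//   XlaBuilder b("sas_example");
//   const Shape scalar = ShapeUtil::MakeShape(F32, {});
//   auto ge_b = b.CreateSubBuilder("ge");
//   Ge(Parameter(ge_b.get(), 0, scalar, "x"),
//      Parameter(ge_b.get(), 1, scalar, "y"));
//   XlaComputation select = ge_b->Build().ValueOrDie();
//   auto add_b = b.CreateSubBuilder("add");
//   Add(Parameter(add_b.get(), 0, scalar, "x"),
//       Parameter(add_b.get(), 1, scalar, "y"));
//   XlaComputation scatter = add_b->Build().ValueOrDie();
//   XlaOp in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 4}), "in");
//   XlaOp src = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "src");
//   XlaOp out = SelectAndScatter(in, select, /*window_dimensions=*/{2, 2},
//                                /*window_strides=*/{2, 2}, Padding::kValid,
//                                src, ConstantR0<float>(&b, 0.0f), scatter);
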
StatusOr<HloInstructionProto> XlaBuilder::SelectAndScatterInternal(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  HloInstructionProto instr;

  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
  TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
  TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
  TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
                      select.GetProgramShape());
  TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
                      scatter.GetProgramShape());
  TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferSelectAndScatterShape(
                          *operand_shape, select_shape, instr.window(),
                          *source_shape, *init_shape, scatter_shape));
  *instr.mutable_shape() = shape.ToProto();

  AddCalledComputation(select, &instr);
  AddCalledComputation(scatter, &instr);
  return instr;
}

XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(HloInstructionProto instr,
                        SelectAndScatterInternal(
                            operand, select, window_dimensions, window_strides,
                            padding, source, init_value, scatter));

    return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
                          {operand, source, init_value});
  });
}

XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferReducePrecisionShape(
                            *operand_shape, exponent_bits, mantissa_bits));
    return ReducePrecisionInternal(shape, operand, exponent_bits,
                                   mantissa_bits);
  });
}

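// Editor's sketch (not part of the original source): rounding f32 values
// through a bfloat16-like format (8 exponent bits, 7 mantissa bits) while
// keeping the f32 element type.
//
//   XlaBuilder b("reduce_precision_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1024}), "x");
//   XlaOp y = ReducePrecision(x, /*exponent_bits=*/8, /*mantissa_bits=*/7);
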
StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
                                                    XlaOp operand,
                                                    const int exponent_bits,
                                                    const int mantissa_bits) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_exponent_bits(exponent_bits);
  instr.set_mantissa_bits(mantissa_bits);
  return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
                        {operand});
}

void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Send HLO takes two operands: a data operand and a token. Generate the
    // token to pass into the send.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    return SendWithToken(operand, token, handle);
  });
}

XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Send must use a device-to-device channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Recv HLO takes a single token operand. Generate the token to pass into
    // the Recv and RecvDone instructions.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    XlaOp recv = RecvWithToken(token, shape, handle);

    // The RecvDone instruction produces a tuple of the data and a token.
    // Return the XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto recv_data;
    *recv_data.mutable_shape() = shape.ToProto();
    recv_data.set_tuple_index(0);
    return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
                          {recv});
  });
}

RecvWithToken(XlaOp token,const Shape & shape,const ChannelHandle & handle)3210 XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
3211                                 const ChannelHandle& handle) {
3212   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3213     if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
3214       return InvalidArgument("Recv must use a device-to-device channel");
3215     }
3216 
3217     // Recv instruction produces a tuple of {receive buffer, U32 context,
3218     // token}.
3219     HloInstructionProto recv_instr;
3220     *recv_instr.mutable_shape() =
3221         ShapeUtil::MakeTupleShape(
3222             {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
3223             .ToProto();
3224     recv_instr.set_channel_id(handle.handle());
3225     TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
3226                                                    HloOpcode::kRecv, {token}));
3227 
3228     HloInstructionProto recv_done_instr;
3229     *recv_done_instr.mutable_shape() =
3230         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
3231             .ToProto();
3232     recv_done_instr.set_channel_id(handle.handle());
3233     return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
3234                           {recv});
3235   });
3236 }
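// Illustrative sketch (editorial addition, not part of the original source):
// pairing the token-threaded transfer ops above through the free functions
// declared in xla_builder.h. The builders and the channel handle are
// hypothetical; a real DEVICE_TO_DEVICE handle would come from the client.
//
//   XlaBuilder sender("sender");
//   XlaOp data = ConstantR0<float>(&sender, 42.0f);
//   XlaOp tok0 = CreateToken(&sender);
//   XlaOp done = SendWithToken(data, tok0, channel);  // kSend + kSendDone
//
//   XlaBuilder receiver("receiver");
//   XlaOp tok1 = CreateToken(&receiver);
//   XlaOp result =
//       RecvWithToken(tok1, ShapeUtil::MakeShape(F32, {}), channel);
//   // `result` is a {data, token} tuple; GetTupleElement(result, 0) is the
//   // received data.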
3237 
3238 XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
3239                              const Shape& shape_with_layout,
3240                              const ChannelHandle& handle) {
3241   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3242     if (!LayoutUtil::HasLayout(shape_with_layout)) {
3243       return InvalidArgument("Shape passed to SendToHost must have a layout");
3244     }
3245     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3246     if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
3247       return InvalidArgument(
3248           "SendToHost shape %s must be compatible with operand shape %s",
3249           ShapeUtil::HumanStringWithLayout(shape_with_layout),
3250           ShapeUtil::HumanStringWithLayout(*operand_shape));
3251     }
3252     // TODO(b/111544877): Support tuple shapes.
3253     if (!operand_shape->IsArray()) {
3254       return InvalidArgument("SendToHost only supports array shapes, shape: %s",
3255                              ShapeUtil::HumanString(*operand_shape));
3256     }
3257 
3258     if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
3259       return InvalidArgument("SendToHost must use a device-to-host channel");
3260     }
3261 
3262     // Send instruction produces a tuple of {aliased operand, U32 context,
3263     // token}.
3264     HloInstructionProto send_instr;
3265     *send_instr.mutable_shape() =
3266         ShapeUtil::MakeTupleShape({shape_with_layout,
3267                                    ShapeUtil::MakeShape(U32, {}),
3268                                    ShapeUtil::MakeTokenShape()})
3269             .ToProto();
3270     send_instr.set_channel_id(handle.handle());
3271     send_instr.set_is_host_transfer(true);
3272     TF_ASSIGN_OR_RETURN(XlaOp send,
3273                         AddInstruction(std::move(send_instr), HloOpcode::kSend,
3274                                        {operand, token}));
3275 
3276     HloInstructionProto send_done_instr;
3277     *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
3278     send_done_instr.set_channel_id(handle.handle());
3279     send_done_instr.set_is_host_transfer(true);
3280     return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
3281                           {send});
3282   });
3283 }
3284 
3285 XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
3286                                const ChannelHandle& handle) {
3287   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3288     if (!LayoutUtil::HasLayout(shape)) {
3289       return InvalidArgument("Shape passed to RecvFromHost must have a layout");
3290     }
3291 
3292     // TODO(b/111544877): Support tuple shapes.
3293     if (!shape.IsArray()) {
3294       return InvalidArgument(
3295           "RecvFromHost only supports array shapes, shape: %s",
3296           ShapeUtil::HumanString(shape));
3297     }
3298 
3299     if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
3300       return InvalidArgument("RecvFromHost must use a host-to-device channel");
3301     }
3302 
3303     // Recv instruction produces a tuple of {receive buffer, U32 context,
3304     // token}.
3305     HloInstructionProto recv_instr;
3306     *recv_instr.mutable_shape() =
3307         ShapeUtil::MakeTupleShape(
3308             {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
3309             .ToProto();
3310     recv_instr.set_channel_id(handle.handle());
3311     recv_instr.set_is_host_transfer(true);
3312     TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
3313                                                    HloOpcode::kRecv, {token}));
3314 
3315     HloInstructionProto recv_done_instr;
3316     *recv_done_instr.mutable_shape() =
3317         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
3318             .ToProto();
3319     recv_done_instr.set_channel_id(handle.handle());
3320     recv_done_instr.set_is_host_transfer(true);
3321     return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
3322                           {recv});
3323   });
3324 }
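// Illustrative sketch (editorial addition, not part of the original source):
// a round trip through the host-transfer variants above. The names are
// hypothetical; the channels must be DEVICE_TO_HOST and HOST_TO_DEVICE
// respectively, and both shapes must carry layouts.
//
//   Shape f32_scalar = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
//   XlaOp tok = CreateToken(&b);
//   XlaOp send_done = SendToHost(operand, tok, f32_scalar, d2h_channel);
//   XlaOp recv = RecvFromHost(send_done, f32_scalar, h2d_channel);
//   // `recv` is a {data, token} tuple keyed by h2d_channel.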
3325 
3326 XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
3327   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3328     HloInstructionProto instr;
3329     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3330     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
3331                                          *operand_shape, dimension));
3332     // Calling GetDimensionSize on a static dimension returns a constant
3333     // instruction.
3334     if (!operand_shape->is_dynamic_dimension(dimension)) {
3335       return ConstantR0<int32>(this, operand_shape->dimensions(dimension));
3336     }
3337     *instr.mutable_shape() = shape.ToProto();
3338     instr.add_dimensions(dimension);
3339     return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
3340                           {operand});
3341   });
3342 }
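// Illustrative sketch (editorial addition, not part of the original source):
// for a fully static shape the builder folds GetDimensionSize into a
// constant at build time; only a dynamic dimension produces a real
// kGetDimensionSize instruction. `b` is a hypothetical builder.
//
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 16}), "p");
//   XlaOp d0 = GetDimensionSize(p, 0);  // folds to ConstantR0<int32>(&b, 8)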
3343 
3344 XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64 dimension) {
3345   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3346     HloInstructionProto instr;
3347     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3348 
3349     Shape shape = *operand_shape;
3350     shape.set_dynamic_dimension(dimension, false);
3351     // Setting an op's dynamic dimension to its static size removes the dynamic
3352     // dimension.
3353     XlaOp static_size =
3354         ConstantR0<int32>(this, operand_shape->dimensions(dimension));
3355 
3356     *instr.mutable_shape() = shape.ToProto();
3357     instr.add_dimensions(dimension);
3358     return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
3359                           {operand, static_size});
3360   });
3361 }
3362 
3363 XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
3364   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3365     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
3366     TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val));
3367 
3368     TF_ASSIGN_OR_RETURN(Shape shape,
3369                         ShapeInference::InferSetDimensionSizeShape(
3370                             *operand_shape, *val_shape, dimension));
3371     return SetDimensionSizeInternal(shape, operand, val, dimension);
3372   });
3373 }
3374 
3375 StatusOr<XlaOp> XlaBuilder::SetDimensionSizeInternal(const Shape& shape,
3376                                                      XlaOp operand, XlaOp val,
3377                                                      int64 dimension) {
3378   // Setting an op's dynamic dimension to the static size is a noop.
3379   TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
3380                       LookUpInstruction(val));
3381   if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
3382       HloOpcode::kConstant) {
3383     TF_ASSIGN_OR_RETURN(auto literal,
3384                         Literal::CreateFromProto(val_proto->literal(), true));
3385     if (literal.Get<int32>({}) == shape.dimensions(dimension)) {
3386       return operand;
3387     }
3388   }
3389 
3390   HloInstructionProto instr;
3391   *instr.mutable_shape() = shape.ToProto();
3392   instr.add_dimensions(dimension);
3393   return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
3394                         {operand, val});
3395 }
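// Illustrative sketch (editorial addition, not part of the original source):
// marking and clearing a dynamic dimension with the ops above. The names are
// hypothetical.
//
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "p");
//   XlaOp n = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "n");
//   XlaOp dyn = SetDimensionSize(p, n, 0);  // dim 0 becomes dynamic, bound 8
//   // Passing the static size instead is folded away by
//   // SetDimensionSizeInternal:
//   //   SetDimensionSize(p, ConstantR0<int32>(&b, 8), 0)  // returns `p`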
3396 
3397 StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
3398   TF_RETURN_IF_ERROR(first_error_);
3399 
3400   // Verify that the handle is valid.
3401   TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
3402 
3403   bool is_constant = true;
3404   absl::flat_hash_set<int64> visited;
3405   IsConstantVisitor(operand.handle(), &visited, &is_constant);
3406   return is_constant;
3407 }
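// Illustrative sketch (editorial addition, not part of the original source):
// IsConstant is purely structural -- it asks whether the value depends on any
// parameter. `b` is a hypothetical builder.
//
//   XlaOp c = Add(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2));
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p");
//   b.IsConstant(c);          // true
//   b.IsConstant(Add(c, p));  // false: depends on a parameter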
3408 
3409 StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
3410   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
3411                       LookUpInstruction(root_op));
3412 
3413   HloComputationProto entry;
3414   SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
3415                     GetNextId());
3416   ProgramShapeProto* program_shape = entry.mutable_program_shape();
3417   *program_shape->mutable_result() =
3418       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
3419 
3420   std::vector<HloComputationProto> called_computations;
3421   auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
3422                                  int64 operand_index) -> StatusOr<bool> {
3423     int64 operand_id = instr_proto->operand_ids(operand_index);
3424     bool is_constant = true;
3425     absl::flat_hash_set<int64> visited;
3426     IsConstantVisitor(operand_id, &visited, &is_constant);
3427     return is_constant;
3428   };
3429   // Process the instruction and copy it into the new graph. The new node in
3430   // the new graph will have its id set to `id`.
3431   auto process_instruction = [&](const HloInstructionProto* instr_proto,
3432                                  bool need_rewrite, int64 id,
3433                                  absl::Span<int64 const> operand_ids,
3434                                  int64* global_id) {
3435     // Rewrite the instruction with following rules:
3436     // - Unary ops: Convert into bitcast (identity) with type Pred.
3437     // - Binary ops: Convert into binary or.
3438     // - Select: Convert into binary or with its two data operands.
3439     // - Concat / Tuple / GTE / Bitcast: Copy.
3440     // - Param: Convert to constant True.
3441     // - GetDimensionSize: Convert to constant True if dimension is dynamic,
3442     //   constant False if dimension is static.
3443     // - Reduce: Convert to reduce or.
3444     // - Constant: Convert to constant False.
3445     // - Reshape, slice, transpose, pad:
3446     //   Convert into predicate type with same opcode.
3447     // - Other ops: Not supported.
3448     // Create the instruction for the new handle.
3449     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
3450                         StringToHloOpcode(instr_proto->opcode()));
3451     auto* new_instr = entry.add_instructions();
3452     *new_instr = *instr_proto;
3453     new_instr->set_id(id);
3454     new_instr->mutable_operand_ids()->Clear();
3455     for (auto operand_id : operand_ids) {
3456       new_instr->mutable_operand_ids()->Add(operand_id);
3457     }
3458 
3459     if (!need_rewrite) {
3460       *new_instr->mutable_name() =
3461           GetFullName(instr_proto->opcode(), kNameSeparator, id);
3462       if (opcode == HloOpcode::kReduce) {
3463         // Copy the reducer to the new module, with a new id that's the same
3464         // as the reduce op's.
3465         HloComputationProto* reducer =
3466             &embedded_[new_instr->called_computation_ids(0)];
3467         int64 reducer_id = (*global_id)++;
3468         new_instr->clear_called_computation_ids();
3469         new_instr->add_called_computation_ids(reducer_id);
3470         called_computations.push_back(CopyReducer(
3471             reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
3472       }
3473       return Status::OK();
3474     }
3475     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
3476     Shape new_shape(new_instr->shape());
3477     switch (opcode) {
3478       case HloOpcode::kAbs:
3479       case HloOpcode::kRoundNearestAfz:
3480       case HloOpcode::kBitcast:
3481       case HloOpcode::kCeil:
3482       case HloOpcode::kCollectivePermuteDone:
3483       case HloOpcode::kCos:
3484       case HloOpcode::kClz:
3485       case HloOpcode::kExp:
3486       case HloOpcode::kExpm1:
3487       case HloOpcode::kFloor:
3488       case HloOpcode::kImag:
3489       case HloOpcode::kIsFinite:
3490       case HloOpcode::kLog:
3491       case HloOpcode::kLog1p:
3492       case HloOpcode::kNot:
3493       case HloOpcode::kNegate:
3494       case HloOpcode::kPopulationCount:
3495       case HloOpcode::kReal:
3496       case HloOpcode::kRsqrt:
3497       case HloOpcode::kLogistic:
3498       case HloOpcode::kSign:
3499       case HloOpcode::kSin:
3500       case HloOpcode::kConvert:
3501       case HloOpcode::kSqrt:
3502       case HloOpcode::kCbrt:
3503       case HloOpcode::kTanh:
3504         CHECK_EQ(instr_proto->operand_ids_size(), 1);
3505         *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
3506         break;
3507       case HloOpcode::kAdd:
3508       case HloOpcode::kAtan2:
3509       case HloOpcode::kDivide:
3510       case HloOpcode::kComplex:
3511       case HloOpcode::kMaximum:
3512       case HloOpcode::kMinimum:
3513       case HloOpcode::kMultiply:
3514       case HloOpcode::kPower:
3515       case HloOpcode::kRemainder:
3516       case HloOpcode::kSubtract:
3517       case HloOpcode::kCompare:
3518       case HloOpcode::kAnd:
3519       case HloOpcode::kOr:
3520       case HloOpcode::kXor:
3521       case HloOpcode::kShiftLeft:
3522       case HloOpcode::kShiftRightArithmetic:
3523       case HloOpcode::kShiftRightLogical:
3524         CHECK_EQ(instr_proto->operand_ids_size(), 2);
3525         *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
3526         break;
3527       case HloOpcode::kSelect: {
3528         TF_ASSIGN_OR_RETURN(bool constant_predicate,
3529                             operand_is_constant(instr_proto, 0));
3530         if (!constant_predicate) {
3531           // If the predicate operand is not constant, conservatively assume
3532           // that the entire result is dynamic.
3533           SetInstructionAsConstant(new_instr, id, new_shape, true);
3534         }
3535         break;
3536       }
3537       case HloOpcode::kGather: {
3538         TF_ASSIGN_OR_RETURN(bool constant_indices,
3539                             operand_is_constant(instr_proto, 1));
3540         if (!constant_indices) {
3541           // If the indices operand is not constant, conservatively assume
3542           // that the entire result is dynamic.
3543           SetInstructionAsConstant(new_instr, id, new_shape, true);
3544         }
3545         break;
3546       }
3547       case HloOpcode::kReduce: {
3548         auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
3549         int64 reducer_id = (*global_id)++;
3550         new_instr->clear_called_computation_ids();
3551         new_instr->add_called_computation_ids(reducer_id);
3552         called_computations.push_back(CopyReducer(
3553             reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
3554         break;
3555       }
3556       case HloOpcode::kTuple:
3557       case HloOpcode::kTranspose:
3558       case HloOpcode::kSlice:
3559       case HloOpcode::kBroadcast:
3560       case HloOpcode::kConcatenate:
3561       case HloOpcode::kReshape:
3562       case HloOpcode::kPad:
3563         break;
3564       case HloOpcode::kGetTupleElement: {
3565         // Rewrite a parameter followed by a GTE into a constant to avoid
3566         // rematerializing the tuple parameter (which could be very large).
3567         int64 operand_handle = instr_proto->operand_ids(0);
3568         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
3569                             LookUpInstructionByHandle(operand_handle));
3570         TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
3571                             StringToHloOpcode(operand_proto->opcode()));
3572         if (operand_opcode == HloOpcode::kParameter) {
3573           SetInstructionAsConstant(new_instr, id, new_shape, true);
3574         }
3575         break;
3576       }
3577       case HloOpcode::kGetDimensionSize: {
3578         int64 dimension = instr_proto->dimensions(0);
3579         int64 operand_handle = instr_proto->operand_ids(0);
3580         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
3581                             LookUpInstructionByHandle(operand_handle));
3582 
3583         SetInstructionAsConstant(
3584             new_instr, id, new_shape,
3585             operand_proto->shape().is_dynamic_dimension(dimension));
3586         break;
3587       }
3588       case HloOpcode::kConstant:
3589       case HloOpcode::kIota:
3590         SetInstructionAsConstant(new_instr, id, new_shape, false);
3591         break;
3592       case HloOpcode::kCustomCall:
3593         if (instr_proto->custom_call_target() == "SetBound") {
3594           SetInstructionAsConstant(new_instr, id, new_shape, true);
3595           break;
3596         } else {
3597           return InvalidArgument(
3598               "Dynamic inferencing on custom call %s is not supported",
3599               instr_proto->DebugString());
3600         }
3601       case HloOpcode::kParameter:
3602         SetInstructionAsConstant(new_instr, id, new_shape, true);
3603         break;
3604       default:
3605         return InvalidArgument("Dynamic inferencing %s is not supported",
3606                                instr_proto->DebugString());
3607     }
3608     *new_instr->mutable_name() =
3609         GetFullName(instr_proto->opcode(), kNameSeparator, id);
3610     return Status::OK();
3611   };
3612 
3613   struct WorkItem {
3614     explicit WorkItem(int64 handle, bool need_rewrite)
3615         : handle(handle), need_rewrite(need_rewrite), visited(false) {}
3616     int64 handle;
3617     // If need_rewrite is true, the instruction will be copied and rewritten
3618     // into a pred instruction indicating whether each value is dynamic. If
3619     // need_rewrite is false, simply copy the instruction to the output graph.
3620     // E.g.,
3621     // For select(P, A, B), we need to rewrite A and B into predicates, but
3622     // don't need to rewrite P.
3623     bool need_rewrite;
3624     // Used in the DFS to remember the ids of this item's processed operands.
3625     std::vector<int64> processed_operands;
3626     // Whether this node has been visited before or not.
3627     bool visited;
3628   };
3629   // Only copy each pair of {handle, need_rewrite} once. The value is the id
3630   // in the new graph.
3631   absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
3632   // Monotonically increasing id to assign to new instructions.
3633   int64 global_id = 0;
3634   // The result id of the last rewritten item -- return value of last stack
3635   // item.
3636   int64 stacktop_id = -1;
3637   std::vector<WorkItem> worklist;
3638   worklist.push_back(WorkItem(root->id(), true));
3639   while (!worklist.empty()) {
3640     WorkItem& item = worklist.back();
3641     auto item_key = std::make_pair(item.handle, item.need_rewrite);
3642     auto iter = seen.find(item_key);
3643     // Already processed this item. Return previous results.
3644     if (iter != seen.end()) {
3645       stacktop_id = iter->second;
3646       worklist.pop_back();
3647       continue;
3648     }
3649 
3650     int64 next_operand = item.processed_operands.size();
3651     TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
3652                         LookUpInstructionByHandle(item.handle));
3653     VLOG(3) << "Processing " << instr_proto->name();
3654     if (!item.visited) {
3655       item.visited = true;
3656     } else {
3657       // Record the previously processed operand.
3658       item.processed_operands.push_back(stacktop_id);
3659       next_operand++;
3660     }
3661     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
3662                         StringToHloOpcode(instr_proto->opcode()));
3663     bool should_visit_operand = true;
3664     if (opcode == HloOpcode::kGetDimensionSize) {
3665       // We rewrite GetDimensionSize instructions into constants based on
3666       // their operand shapes, so there is no need to visit their operands.
3667       should_visit_operand = false;
3668     }
3669 
3670     if (opcode == HloOpcode::kGetTupleElement) {
3671       int64 operand_handle = instr_proto->operand_ids(0);
3672       TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
3673                           LookUpInstructionByHandle(operand_handle));
3674       TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
3675                           StringToHloOpcode(operand_proto->opcode()));
3676       if (operand_opcode == HloOpcode::kParameter) {
3677         // Don't rematerialize the whole parameter if it's followed by a GTE.
3678         should_visit_operand = false;
3679       }
3680     }
3681 
3682     if (opcode == HloOpcode::kSelect) {
3683       TF_ASSIGN_OR_RETURN(bool constant_predicate,
3684                           operand_is_constant(instr_proto, 0));
3685       // If the predicate operand (first operand) is non-constant, we don't
3686       // visit the operands and we say that all result values are dynamic.
3687       should_visit_operand = constant_predicate;
3688     }
3689     if (opcode == HloOpcode::kGather) {
3690       TF_ASSIGN_OR_RETURN(bool constant_indices,
3691                           operand_is_constant(instr_proto, 1));
3692       // If the indices operand (second operand) is non-constant, we don't
3693       // visit the operands and we say that all result values are dynamic.
3694       should_visit_operand = constant_indices;
3695     }
3696     if (opcode == HloOpcode::kGetDimensionSize && next_operand == 0) {
3697       // Always rewrite GetDimensionSize into a constant.
3698       item.need_rewrite = true;
3699     }
3700     if (next_operand >= instr_proto->operand_ids_size() ||
3701         !should_visit_operand || InstrIsSetBound(instr_proto)) {
3702       // No more operands to process, process self.
3703       int64 new_id = global_id++;
3704       VLOG(3) << "new_id: " << new_id << " instr: " << instr_proto->name();
3705       TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
3706                                              new_id, item.processed_operands,
3707                                              &global_id));
3708       stacktop_id = new_id;
3709       seen[item_key] = stacktop_id;
3710       worklist.pop_back();
3711     } else {
3712       // Visit and process operand.  If an operand doesn't need rewrite
3713       // (predicate of kSelect, or indices of kGather), we also don't rewrite
3714       // its ancestors.
3715       WorkItem next_item(instr_proto->operand_ids(next_operand),
3716                          item.need_rewrite);
3717       if (opcode == HloOpcode::kSelect && next_operand == 0) {
3718         next_item.need_rewrite = false;
3719       }
3720       if (opcode == HloOpcode::kGather && next_operand == 1) {
3721         next_item.need_rewrite = false;
3722       }
3723       // Push next operand into worklist.
3724       worklist.push_back(next_item);
3725     }
3726   }
3727   TF_RET_CHECK(stacktop_id != -1);
3728   entry.set_root_id(stacktop_id);
3729   absl::c_sort(*entry.mutable_instructions(),
3730                [](const HloInstructionProto& p1,
3731                   const HloInstructionProto& p2) { return p1.id() < p2.id(); });
3732   XlaComputation computation(entry.id());
3733   HloModuleProto* module = computation.mutable_proto();
3734   module->set_name(entry.name());
3735   module->set_id(entry.id());
3736   module->set_entry_computation_name(entry.name());
3737   module->set_entry_computation_id(entry.id());
3738   *module->mutable_host_program_shape() = *program_shape;
3739   for (auto& called_comp : called_computations) {
3740     *module->add_computations() = called_comp;
3741   }
3742   *module->add_computations() = std::move(entry);
3743   // Make sure all ids appear in the computation in ascending order.
3744   absl::c_sort(*module->mutable_computations(),
3745                [](const HloComputationProto& c1,
3746                   const HloComputationProto& c2) { return c1.id() < c2.id(); });
3747   XLA_VLOG_LINES(3, module->DebugString());
3748   return std::move(computation);
3749 }
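// Illustrative sketch (editorial addition, not part of the original source)
// of the rewrite this function performs: for root = Add(p, c), where `p` is a
// parameter and `c` a constant, the emitted pred computation is roughly
//
//   %p'   = pred[] constant(true)    // parameters are dynamic
//   %c'   = pred[] constant(false)   // constants are static
//   %root = pred[] or(%p', %c')      // binary ops become OR
//
// so each output element is true iff its value may be dynamic.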
3750 
3751 StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
3752     XlaOp root_op, bool dynamic_dimension_is_minus_one) {
3753   TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
3754   if (!is_constant) {
3755     auto op_status = LookUpInstruction(root_op);
3756     string op_string =
3757         op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
3758     return InvalidArgument(
3759         "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
3760         "  op requested for constant subgraph: %s\n\n"
3761         "This is an internal error that typically happens when the XLA user "
3762         "(e.g. TensorFlow) is attempting to determine a value that must be a "
3763         "compile-time constant (e.g. an array dimension) but it is not capable "
3764         "of being evaluated at XLA compile time.\n\n"
3765         "Please file a usability bug with the framework being used (e.g. "
3766         "TensorFlow).",
3767         op_string);
3768   }
3769 
3770   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
3771                       LookUpInstruction(root_op));
3772 
3773   HloComputationProto entry;
3774   SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
3775                     GetNextId());
3776   entry.set_root_id(root->id());
3777   ProgramShapeProto* program_shape = entry.mutable_program_shape();
3778   *program_shape->mutable_result() = root->shape();
3779 
3780   // We use std::set to keep the instruction ids in ascending order (which is
3781   // also a valid dependency order). The related ops will be added to the
3782   // subgraph in the same order.
3783   std::set<int64> related_ops;
3784   absl::flat_hash_set<int64> related_calls;  // Related computations.
3785   std::queue<int64> worklist;
3786   worklist.push(root->id());
3787   related_ops.insert(root->id());
3788   while (!worklist.empty()) {
3789     int64 handle = worklist.front();
3790     worklist.pop();
3791     TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
3792                         LookUpInstructionByHandle(handle));
3793 
3794     if (instr_proto->opcode() ==
3795             HloOpcodeString(HloOpcode::kGetDimensionSize) ||
3796         InstrIsSetBound(instr_proto)) {
3797       int32 constant_value = -1;
3798       HloInstructionProto const_instr;
3799 
3800       if (instr_proto->opcode() ==
3801           HloOpcodeString(HloOpcode::kGetDimensionSize)) {
3802         // At this point, BuildConstantSubGraph should never encounter a
3803         // GetDimensionSize with a dynamic dimension; the IsConstant check
3804         // would have failed at the beginning of this function.
3805         //
3806         // Replace GetDimensionSize with a Constant representing the static
3807         // bound of the shape.
3808         int64 dimension = instr_proto->dimensions(0);
3809         int64 operand_handle = instr_proto->operand_ids(0);
3810         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
3811                             LookUpInstructionByHandle(operand_handle));
3812 
3813         if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
3814               dynamic_dimension_is_minus_one)) {
3815           constant_value =
3816               static_cast<int32>(operand_proto->shape().dimensions(dimension));
3817         }
3818         Literal literal = LiteralUtil::CreateR0(constant_value);
3819         *const_instr.mutable_literal() = literal.ToProto();
3820         *const_instr.mutable_shape() = literal.shape().ToProto();
3821       } else {
3822         *const_instr.mutable_literal() = instr_proto->literal();
3823         *const_instr.mutable_shape() = instr_proto->shape();
3824       }
3825       *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
3826       const_instr.set_id(handle);
3827       *const_instr.mutable_name() =
3828           GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
3829       *entry.add_instructions() =
3830           const_instr;  // Add to the result constant graph.
3831     } else {
3832       for (int64 id : instr_proto->operand_ids()) {
3833         if (related_ops.insert(id).second) {
3834           worklist.push(id);
3835         }
3836       }
3837       for (int64 called_id : instr_proto->called_computation_ids()) {
3838         related_calls.insert(called_id);
3839       }
3840     }
3841   }
3842 
3843   // Add related ops to the computation.
3844   for (int64 id : related_ops) {
3845     TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
3846                         LookUpInstructionByHandle(id));
3847 
3848     if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
3849       continue;
3850     }
3851     if (InstrIsSetBound(instr_src)) {
3852       continue;
3853     }
3854     auto* instr = entry.add_instructions();
3855 
3856     *instr = *instr_src;
3857     // Ensures that the instruction names are unique within the graph.
3858     const string& new_name =
3859         StrCat(instr->name(), ".", entry.id(), ".", instr->id());
3860     instr->set_name(new_name);
3861   }
3862 
3863   XlaComputation computation(entry.id());
3864   HloModuleProto* module = computation.mutable_proto();
3865   module->set_name(entry.name());
3866   module->set_id(entry.id());
3867   module->set_entry_computation_name(entry.name());
3868   module->set_entry_computation_id(entry.id());
3869   *module->mutable_host_program_shape() = *program_shape;
3870   for (auto& e : embedded_) {
3871     if (related_calls.find(e.second.id()) != related_calls.end()) {
3872       *module->add_computations() = e.second;
3873     }
3874   }
3875   *module->add_computations() = std::move(entry);
3876   return std::move(computation);
3877 }
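// Illustrative sketch (editorial addition, not part of the original source):
// typical use of BuildConstantSubGraph, e.g. to resolve a shape value at
// compile time. `b` is a hypothetical builder.
//
//   XlaOp p   = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 4}), "p");
//   XlaOp dim = GetDimensionSize(p, 1);  // folds to the static bound 4
//   XlaOp two = ConstantR0<int32>(&b, 2);
//   StatusOr<XlaComputation> subgraph = b.BuildConstantSubGraph(
//       Mul(dim, two), /*dynamic_dimension_is_minus_one=*/false);
//   // The subgraph contains only constants and the multiply; evaluating it
//   // yields int32 8 without touching parameter `p`.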
3878 
3879 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
3880     const string& computation_name) {
3881   auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
3882   sub_builder->parent_builder_ = this;
3883   sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
3884   return sub_builder;
3885 }
3886 
3887 /* static */ ConvolutionDimensionNumbers
3888 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
3889   ConvolutionDimensionNumbers dimension_numbers;
3890   dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
3891   dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
3892   dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
3893   dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
3894   dimension_numbers.set_kernel_output_feature_dimension(
3895       kConvKernelOutputDimension);
3896   dimension_numbers.set_kernel_input_feature_dimension(
3897       kConvKernelInputDimension);
3898   for (int i = 0; i < num_spatial_dims; ++i) {
3899     dimension_numbers.add_input_spatial_dimensions(i + 2);
3900     dimension_numbers.add_kernel_spatial_dimensions(i + 2);
3901     dimension_numbers.add_output_spatial_dimensions(i + 2);
3902   }
3903   return dimension_numbers;
3904 }
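// Illustrative note (editorial addition, not part of the original source):
// with the kConv* constants declared in xla_builder.h, this yields the usual
// NCHW / OIHW numbering, e.g. for num_spatial_dims == 2:
//
//   input:  batch=0, feature=1, spatial={2, 3}
//   kernel: output_feature=0, input_feature=1, spatial={2, 3}
//   output: batch=0, feature=1, spatial={2, 3}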
3905 
3906 /* static */ Status XlaBuilder::Validate(
3907     const ConvolutionDimensionNumbers& dnum) {
3908   if (dnum.input_spatial_dimensions_size() < 2) {
3909     return FailedPrecondition("input spatial dimension < 2: %d",
3910                               dnum.input_spatial_dimensions_size());
3911   }
3912   if (dnum.kernel_spatial_dimensions_size() < 2) {
3913     return FailedPrecondition("kernel spatial dimension < 2: %d",
3914                               dnum.kernel_spatial_dimensions_size());
3915   }
3916   if (dnum.output_spatial_dimensions_size() < 2) {
3917     return FailedPrecondition("output spatial dimension < 2: %d",
3918                               dnum.output_spatial_dimensions_size());
3919   }
3920 
3921   if (std::set<int64>(
3922           {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
3923            dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
3924           .size() != 4) {
3925     return FailedPrecondition(
3926         "dimension numbers for the input are not unique: (%d, %d, %d, "
3927         "%d)",
3928         dnum.input_batch_dimension(), dnum.input_feature_dimension(),
3929         dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
3930   }
3931   if (std::set<int64>({dnum.kernel_output_feature_dimension(),
3932                        dnum.kernel_input_feature_dimension(),
3933                        dnum.kernel_spatial_dimensions(0),
3934                        dnum.kernel_spatial_dimensions(1)})
3935           .size() != 4) {
3936     return FailedPrecondition(
3937         "dimension numbers for the weight are not unique: (%d, %d, %d, "
3938         "%d)",
3939         dnum.kernel_output_feature_dimension(),
3940         dnum.kernel_input_feature_dimension(),
3941         dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
3942   }
3943   if (std::set<int64>({dnum.output_batch_dimension(),
3944                        dnum.output_feature_dimension(),
3945                        dnum.output_spatial_dimensions(0),
3946                        dnum.output_spatial_dimensions(1)})
3947           .size() != 4) {
3948     return FailedPrecondition(
3949         "dimension numbers for the output are not unique: (%d, %d, %d, "
3950         "%d)",
3951         dnum.output_batch_dimension(), dnum.output_feature_dimension(),
3952         dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
3953   }
3954   return Status::OK();
3955 }
3956 
3957 StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
3958                                            HloOpcode opcode,
3959                                            absl::Span<const XlaOp> operands) {
3960   TF_RETURN_IF_ERROR(first_error_);
3961 
3962   const int64 handle = GetNextId();
3963   instr.set_id(handle);
3964   instr.set_opcode(HloOpcodeString(opcode));
3965   if (instr.name().empty()) {
3966     instr.set_name(instr.opcode());
3967   }
3968   for (const auto& operand : operands) {
3969     if (operand.builder_ == nullptr) {
3970       return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
3971     }
3972     if (operand.builder_ != this) {
3973       return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
3974                              operand.builder_->name(), this->name());
3975     }
3976     instr.add_operand_ids(operand.handle());
3977   }
3978 
3979   if (one_shot_metadata_.has_value()) {
3980     *instr.mutable_metadata() = one_shot_metadata_.value();
3981     one_shot_metadata_.reset();
3982   } else {
3983     *instr.mutable_metadata() = metadata_;
3984   }
3985   if (sharding_) {
3986     *instr.mutable_sharding() = *sharding_;
3987   }
3988   *instr.mutable_frontend_attributes() = frontend_attributes_;
3989 
3990   handle_to_index_[handle] = instructions_.size();
3991   instructions_.push_back(std::move(instr));
3992   instruction_shapes_.push_back(
3993       absl::make_unique<Shape>(instructions_.back().shape()));
3994 
3995   XlaOp op(handle, this);
3996   return op;
3997 }
3998 
3999 StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
4000                                            absl::Span<const XlaOp> operands) {
4001   HloInstructionProto instr;
4002   *instr.mutable_shape() = shape.ToProto();
4003   return AddInstruction(std::move(instr), opcode, operands);
4004 }
4005 
4006 void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
4007                                       HloInstructionProto* instr) {
4008   absl::flat_hash_map<int64, int64> remapped_ids;
4009   std::vector<HloComputationProto> imported_computations;
4010   imported_computations.reserve(computation.proto().computations_size());
4011   // First, import the computations, remapping their IDs and capturing the
4012   // old->new mappings in remapped_ids.
4013   for (const HloComputationProto& e : computation.proto().computations()) {
4014     HloComputationProto new_computation(e);
4015     int64 computation_id = GetNextId();
4016     remapped_ids[new_computation.id()] = computation_id;
4017     SetProtoIdAndName(&new_computation,
4018                       GetBaseName(new_computation.name(), kNameSeparator),
4019                       kNameSeparator, computation_id);
4020     for (auto& instruction : *new_computation.mutable_instructions()) {
4021       int64 instruction_id = GetNextId();
4022       remapped_ids[instruction.id()] = instruction_id;
4023       SetProtoIdAndName(&instruction,
4024                         GetBaseName(instruction.name(), kNameSeparator),
4025                         kNameSeparator, instruction_id);
4026     }
4027     new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));
4028 
4029     imported_computations.push_back(std::move(new_computation));
4030   }
4031   // Once we have imported all the computations and captured all the ID
4032   // mappings, we go back and fix up the IDs in the imported computations.
4033   instr->add_called_computation_ids(
4034       remapped_ids.at(computation.proto().entry_computation_id()));
4035   for (auto& imported_computation : imported_computations) {
4036     for (auto& instruction : *imported_computation.mutable_instructions()) {
4037       for (auto& operand_id : *instruction.mutable_operand_ids()) {
4038         operand_id = remapped_ids.at(operand_id);
4039       }
4040       for (auto& control_predecessor_id :
4041            *instruction.mutable_control_predecessor_ids()) {
4042         control_predecessor_id = remapped_ids.at(control_predecessor_id);
4043       }
4044       for (auto& called_computation_id :
4045            *instruction.mutable_called_computation_ids()) {
4046         called_computation_id = remapped_ids.at(called_computation_id);
4047       }
4048     }
4049 
4050     int64 computation_id = imported_computation.id();
4051     embedded_.insert({computation_id, std::move(imported_computation)});
4052   }
4053 }
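// Worked example (editorial addition, not part of the original source) of the
// remapping above: if a subcomputation arrives with computation id 5 whose
// root instruction has id 7, and this builder's next ids are 100 and 101,
// then remapped_ids = {5 -> 100, 7 -> 101}. Every operand, control
// predecessor, and called-computation id inside the imported protos is then
// rewritten through remapped_ids, and the caller's called_computation_ids
// receives 100 for the entry computation.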
4054 
4055 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
4056     const XlaOp op) const {
4057   TF_RETURN_IF_ERROR(first_error_);
4058   return LookUpInstructionInternal<const HloInstructionProto*>(op);
4059 }
4060 
4061 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
4062     int64 handle) const {
4063   return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
4064 }
4065 
4066 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
4067     const XlaOp op) {
4068   TF_RETURN_IF_ERROR(first_error_);
4069   return LookUpInstructionInternal<HloInstructionProto*>(op);
4070 }
4071 
4072 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
4073     int64 handle) {
4074   return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
4075 }
4076 
4077 // Enqueues a "retrieve parameter value" instruction for a parameter that was
4078 // passed to the computation.
4079 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
4080                 const string& name) {
4081   std::vector<bool> empty_bools;
4082   return Parameter(builder, parameter_number, shape, name, empty_bools);
4083 }
4084 
4085 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
4086                 const string& name,
4087                 const std::vector<bool>& replicated_at_leaf_buffers) {
4088   return builder->Parameter(parameter_number, shape, name,
4089                             replicated_at_leaf_buffers);
4090 }
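// Illustrative sketch (editorial addition, not part of the original source):
// declaring parameters through the free-function API above.
//
//   XlaBuilder b("example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
//   XlaOp sum = Add(x, y);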
4091 
4092 // Enqueues a constant with the value of the given literal onto the
4093 // computation.
4094 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
4095   return builder->ConstantLiteral(literal);
4096 }
4097 
4098 XlaOp Broadcast(const XlaOp operand, absl::Span<const int64> broadcast_sizes) {
4099   return operand.builder()->Broadcast(operand, broadcast_sizes);
4100 }
4101 
4102 XlaOp BroadcastInDim(const XlaOp operand,
4103                      const absl::Span<const int64> out_dim_size,
4104                      const absl::Span<const int64> broadcast_dimensions) {
4105   return operand.builder()->BroadcastInDim(operand, out_dim_size,
4106                                            broadcast_dimensions);
4107 }
4108 
4109 XlaOp Copy(const XlaOp operand) {
4110   return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
4111 }
4112 
4113 XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
4114           const PaddingConfig& padding_config) {
4115   return operand.builder()->Pad(operand, padding_value, padding_config);
4116 }
4117 
4118 XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno, int64 pad_lo,
4119                int64 pad_hi) {
4120   return operand.builder()->PadInDim(operand, padding_value, dimno, pad_lo,
4121                                      pad_hi);
4122 }
4123 
4124 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> dimensions,
4125               absl::Span<const int64> new_sizes) {
4126   return operand.builder()->Reshape(operand, dimensions, new_sizes);
4127 }
4128 
4129 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
4130   return operand.builder()->Reshape(operand, new_sizes);
4131 }
4132 
4133 XlaOp Reshape(const Shape& shape, XlaOp operand) {
4134   return operand.builder()->Reshape(shape, operand);
4135 }
4136 
4137 XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
4138                      absl::Span<const int64> new_size_bounds,
4139                      const std::vector<bool>& dims_are_dynamic) {
4140   return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
4141                                            dims_are_dynamic);
4142 }
4143 
4144 XlaOp ReshapeWithInferredDimension(XlaOp operand,
4145                                    absl::Span<const int64> new_sizes,
4146                                    int64 inferred_dimension) {
4147   return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
4148 }
4149 
4150 XlaOp Collapse(const XlaOp operand, absl::Span<const int64> dimensions) {
4151   return operand.builder()->Collapse(operand, dimensions);
4152 }
4153 
4154 XlaOp Slice(const XlaOp operand, absl::Span<const int64> start_indices,
4155             absl::Span<const int64> limit_indices,
4156             absl::Span<const int64> strides) {
4157   return operand.builder()->Slice(operand, start_indices, limit_indices,
4158                                   strides);
4159 }
4160 
4161 XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index,
4162                  int64 stride, int64 dimno) {
4163   return operand.builder()->SliceInDim(operand, start_index, limit_index,
4164                                        stride, dimno);
4165 }
4166 
4167 XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
4168                    absl::Span<const int64> slice_sizes) {
4169   return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
4170 }
4171 
4172 XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
4173                          absl::Span<const XlaOp> start_indices) {
4174   return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
4175 }
4176 
4177 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
4178                   int64 dimension) {
4179   return builder->ConcatInDim(operands, dimension);
4180 }
4181 
4182 void Trace(const string& tag, const XlaOp operand) {
4183   return operand.builder()->Trace(tag, operand);
4184 }
4185 
4186 XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
4187   return pred.builder()->Select(pred, on_true, on_false);
4188 }
4189 
4190 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
4191   return builder->Tuple(elements);
4192 }
4193 
4194 XlaOp GetTupleElement(const XlaOp tuple_data, int64 index) {
4195   return tuple_data.builder()->GetTupleElement(tuple_data, index);
4196 }
4197 
4198 XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
4199          absl::Span<const int64> broadcast_dimensions) {
4200   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
4201 }
4202 
4203 static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
4204                                absl::Span<const int64> broadcast_dimensions,
4205                                ComparisonDirection comparison_direction) {
4206   auto b = lhs.builder();
4207   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
4208     TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
4209     auto operand_element_type = operand_shape.element_type();
4210     auto compare_type =
4211         primitive_util::IsFloatingPointType(operand_element_type)
4212             ? Comparison::Type::kFloatTotalOrder
4213             : Comparison::DefaultComparisonType(operand_element_type);
4214     return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
4215                    compare_type);
4216   });
4217 }
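// Illustrative note (editorial addition, not part of the original source):
// for floating-point operands the total order differs from the default
// partial order in how it treats NaNs. `b` is a hypothetical builder.
//
//   XlaOp nan = ConstantR0<float>(&b, std::numeric_limits<float>::quiet_NaN());
//   XlaOp one = ConstantR0<float>(&b, 1.0f);
//   Le(nan, one);            // false: NaN compares unordered
//   LeTotalOrder(nan, one);  // well-defined: the total order ranks every
//                            // value, including NaNs, so the result is
//                            // deterministic.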
4218 
4219 XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
4220                    absl::Span<const int64> broadcast_dimensions) {
4221   return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4222                            ComparisonDirection::kEq);
4223 }
4224 
4225 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
4226          absl::Span<const int64> broadcast_dimensions) {
4227   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
4228 }
4229 
4230 XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4231                    absl::Span<const int64> broadcast_dimensions) {
4232   return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4233                            ComparisonDirection::kNe);
4234 }
4235 
4236 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
4237          absl::Span<const int64> broadcast_dimensions) {
4238   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
4239 }
4240 
4241 XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4242                    absl::Span<const int64> broadcast_dimensions) {
4243   return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4244                            ComparisonDirection::kGe);
4245 }
4246 
4247 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
4248          absl::Span<const int64> broadcast_dimensions) {
4249   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
4250 }
4251 
4252 XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
4253                    absl::Span<const int64> broadcast_dimensions) {
4254   return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4255                            ComparisonDirection::kGt);
4256 }
4257 
4258 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
4259          absl::Span<const int64> broadcast_dimensions) {
4260   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
4261 }
4262 
4263 XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
4264                    absl::Span<const int64> broadcast_dimensions) {
4265   return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4266                            ComparisonDirection::kLe);
4267 }
4268 
4269 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
4270          absl::Span<const int64> broadcast_dimensions) {
4271   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
4272 }
4273 
4274 XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
4275                    absl::Span<const int64> broadcast_dimensions) {
4276   return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
4277                            ComparisonDirection::kLt);
4278 }
4279 
4280 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
4281               absl::Span<const int64> broadcast_dimensions,
4282               ComparisonDirection direction) {
4283   return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
4284                                  broadcast_dimensions, direction);
4285 }
4286 
4287 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
4288               absl::Span<const int64> broadcast_dimensions,
4289               ComparisonDirection direction, Comparison::Type compare_type) {
4290   return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
4291                                  broadcast_dimensions, direction, compare_type);
4292 }
4293 
4294 XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
4295   return Compare(lhs, rhs, {}, direction);
4296 }
4297 
4298 XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
4299           const PrecisionConfig* precision_config,
4300           absl::optional<PrimitiveType> preferred_element_type) {
4301   return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
4302 }
4303 
4304 XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
4305                  const DotDimensionNumbers& dimension_numbers,
4306                  const PrecisionConfig* precision_config,
4307                  absl::optional<PrimitiveType> preferred_element_type) {
4308   return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
4309                                    precision_config, preferred_element_type);
4310 }
4311 
4312 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
4313            absl::Span<const int64> window_strides, Padding padding,
4314            int64 feature_group_count, int64 batch_group_count,
4315            const PrecisionConfig* precision_config,
4316            absl::optional<PrimitiveType> preferred_element_type) {
4317   return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
4318                              feature_group_count, batch_group_count,
4319                              precision_config, preferred_element_type);
4320 }
4321 
4322 XlaOp ConvWithGeneralPadding(
4323     const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
4324     absl::Span<const std::pair<int64, int64>> padding,
4325     int64 feature_group_count, int64 batch_group_count,
4326     const PrecisionConfig* precision_config,
4327     absl::optional<PrimitiveType> preferred_element_type) {
4328   return lhs.builder()->ConvWithGeneralPadding(
4329       lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
4330       precision_config, preferred_element_type);
4331 }
4332 
4333 XlaOp ConvWithGeneralDimensions(
4334     const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
4335     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
4336     int64 feature_group_count, int64 batch_group_count,
4337     const PrecisionConfig* precision_config,
4338     absl::optional<PrimitiveType> preferred_element_type) {
4339   return lhs.builder()->ConvWithGeneralDimensions(
4340       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4341       batch_group_count, precision_config, preferred_element_type);
4342 }
4343 
ConvGeneral(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,const ConvolutionDimensionNumbers & dimension_numbers,int64 feature_group_count,int64 batch_group_count,const PrecisionConfig * precision_config,absl::optional<PrimitiveType> preferred_element_type)4344 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
4345                   absl::Span<const int64> window_strides,
4346                   absl::Span<const std::pair<int64, int64>> padding,
4347                   const ConvolutionDimensionNumbers& dimension_numbers,
4348                   int64 feature_group_count, int64 batch_group_count,
4349                   const PrecisionConfig* precision_config,
4350                   absl::optional<PrimitiveType> preferred_element_type) {
4351   return lhs.builder()->ConvGeneral(
4352       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
4353       batch_group_count, precision_config, preferred_element_type);
4354 }
4355 
ConvGeneralDilated(const XlaOp lhs,const XlaOp rhs,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,absl::Span<const int64> lhs_dilation,absl::Span<const int64> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64 feature_group_count,int64 batch_group_count,const PrecisionConfig * precision_config,absl::optional<PrimitiveType> preferred_element_type)4356 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
4357                          absl::Span<const int64> window_strides,
4358                          absl::Span<const std::pair<int64, int64>> padding,
4359                          absl::Span<const int64> lhs_dilation,
4360                          absl::Span<const int64> rhs_dilation,
4361                          const ConvolutionDimensionNumbers& dimension_numbers,
4362                          int64 feature_group_count, int64 batch_group_count,
4363                          const PrecisionConfig* precision_config,
4364                          absl::optional<PrimitiveType> preferred_element_type) {
4365   return lhs.builder()->ConvGeneralDilated(
4366       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
4367       dimension_numbers, feature_group_count, batch_group_count,
4368       precision_config, preferred_element_type);
4369 }
4370 
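// Illustrative sketch (input/kernel are placeholders): the simple Conv
// wrapper covers the common case of a same-padded convolution with unit
// strides, leaning on the default arguments declared in xla_builder.h.
// ConvGeneralDilated additionally exposes lhs_dilation (transposed-
// convolution style) and rhs_dilation (atrous/dilated kernels):
//
//   XlaOp out = Conv(input, kernel, /*window_strides=*/{1, 1},
//                    Padding::kSame);
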
XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvInputGrad(
      input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return activations.builder()->DynamicConvKernelGrad(
      activations, gradients, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64 feature_group_count, int64 batch_group_count,
                         const PrecisionConfig* precision_config,
                         PaddingType padding_type,
                         absl::optional<PrimitiveType> preferred_element_type) {
  return lhs.builder()->DynamicConvForward(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}

XlaOp Fft(const XlaOp operand, FftType fft_type,
          absl::Span<const int64> fft_length) {
  return operand.builder()->Fft(operand, fft_type, fft_length);
}

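// Illustrative sketch: a forward complex-to-complex FFT over the last
// dimension of a C64 operand of shape [..., 1024] (FftType comes from
// xla_data.proto; RFFT/IRFFT cover the real-input variants):
//
//   XlaOp freq = Fft(signal, FftType::FFT, /*fft_length=*/{1024});
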
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
    xla::TriangularSolveOptions options;
    options.set_left_side(left_side);
    options.set_lower(lower);
    options.set_unit_diagonal(unit_diagonal);
    options.set_transpose_a(transpose_a);
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
                                         *a_shape, *b_shape, options));
    return builder->TriangularSolveInternal(shape, a, b, std::move(options));
  });
}

XlaOp Cholesky(XlaOp a, bool lower) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferCholeskyShape(*a_shape));
    return builder->CholeskyInternal(shape, a, lower);
  });
}

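// Illustrative sketch: solving a * x = b for x with a lower-triangular,
// non-unit-diagonal a applied from the left:
//
//   XlaOp x = TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);
//
// A common pairing is Cholesky(a, /*lower=*/true) followed by two such
// solves to solve a symmetric positive-definite system.
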
XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
  return builder->Infeed(shape, config);
}

void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
             const string& outfeed_config) {
  return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
}

XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands) {
  return builder->Call(computation, operands);
}

XlaOp CustomCall(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             /*operand_shapes_with_layout=*/absl::nullopt,
                             has_side_effect, output_operand_aliasing, literal);
}

XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  return builder->CustomCall(call_target_name, operands, computation, shape,
                             opaque,
                             /*operand_shapes_with_layout=*/absl::nullopt,
                             has_side_effect, output_operand_aliasing, literal);
}

XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             operand_shapes_with_layout, has_side_effect,
                             output_operand_aliasing, literal);
}

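// Illustrative sketch (target name and shapes are placeholders): calling a
// backend-registered custom-call target that maps one F32[128] operand to an
// F32[128] result; registering "my_target" with the backend happens outside
// the builder API:
//
//   XlaOp y = CustomCall(builder, "my_target", /*operands=*/{x},
//                        ShapeUtil::MakeShape(F32, {128}));
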
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Conj(const XlaOp operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}

XlaOp Add(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Div(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Max(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Min(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp And(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Or(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Not(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
}

XlaOp PopulationCount(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
}

XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
                absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
                        absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
                                 broadcast_dimensions);
}

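// Illustrative sketch: for all of the binary ops above, broadcast_dimensions
// maps each dimension of the lower-rank operand onto a dimension of the
// higher-rank operand. Adding an F32[N] bias to every row of an F32[M, N]
// matrix maps the vector's dimension 0 onto the matrix's dimension 1:
//
//   XlaOp biased = Add(matrix, bias, /*broadcast_dimensions=*/{1});
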
XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return operand.builder()->Reduce(operand, init_value, computation,
                                   dimensions_to_reduce);
}

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return builder->Reduce(operands, init_values, computation,
                         dimensions_to_reduce);
}

XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
                const XlaComputation& computation) {
  return operand.builder()->ReduceAll(operand, init_value, computation);
}

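// Illustrative sketch: summing an F32 operand over dimension 0 with the
// scalar-add helper from tensorflow/compiler/xla/client/lib/arithmetic.h:
//
//   XlaComputation add = CreateScalarAddComputation(F32, builder);
//   XlaOp sums = Reduce(operand, ConstantR0<float>(builder, 0.0f), add,
//                       /*dimensions_to_reduce=*/{0});
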
XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  return operand.builder()->ReduceWindow(operand, init_value, computation,
                                         window_dimensions, window_strides,
                                         padding);
}

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindow(operands, init_values, computation,
                                             window_dimensions, window_strides,
                                             padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp operand, const XlaOp init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return operand.builder()->ReduceWindowWithGeneralPadding(
      absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
      window_dimensions, window_strides, base_dilations, window_dilations,
      padding);
}

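// Illustrative sketch: 2x2 max pooling with stride 2 over an NCHW operand,
// using the scalar-max helper from client/lib/arithmetic.h and -infinity as
// the identity element for max:
//
//   XlaComputation max = CreateScalarMaxComputation(F32, builder);
//   XlaOp init = ConstantR0<float>(
//       builder, -std::numeric_limits<float>::infinity());
//   XlaOp pooled = ReduceWindow(input, init, max,
//                               /*window_dimensions=*/{1, 1, 2, 2},
//                               /*window_strides=*/{1, 1, 2, 2},
//                               Padding::kValid);
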
XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
                int64 shard_count,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Layout>& layout,
                const absl::optional<bool> use_global_device_ids) {
  return operand.builder()->AllGather(operand, all_gather_dimension,
                                      shard_count, replica_groups, channel_id,
                                      layout, use_global_device_ids);
}

XlaOp CrossReplicaSum(const XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups) {
  return operand.builder()->CrossReplicaSum(operand, replica_groups);
}

XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Shape>& shape_with_layout) {
  return operand.builder()->AllReduce(operand, computation, replica_groups,
                                      channel_id, shape_with_layout);
}

XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
               int64 concat_dimension, int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups,
               const absl::optional<Layout>& layout) {
  return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
                                     split_count, replica_groups, layout);
}

XlaOp AllToAllTuple(const XlaOp operand, int64 split_dimension,
                    int64 concat_dimension, int64 split_count,
                    const std::vector<ReplicaGroup>& replica_groups,
                    const absl::optional<Layout>& layout) {
  return operand.builder()->AllToAllTuple(operand, split_dimension,
                                          concat_dimension, split_count,
                                          replica_groups, layout);
}

XlaOp CollectivePermute(
    const XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return operand.builder()->CollectivePermute(operand, source_target_pairs);
}

XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }

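// Illustrative sketch: summing a value across all replicas. An empty
// replica_groups span (the header's default) places every replica in one
// group; CrossReplicaSum(operand) is shorthand for the same summation:
//
//   XlaComputation add = CreateScalarAddComputation(F32, builder);
//   XlaOp global_sum = AllReduce(operand, add);
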
XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp source, const XlaOp init_value,
                       const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
                                             window_strides, padding, source,
                                             init_value, scatter);
}

XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp source,
    const XlaOp init_value, const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      init_value, scatter);
}

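// Illustrative sketch (forward_input/grad_out are placeholders):
// SelectAndScatter is the scatter-back companion of a windowed reduction,
// e.g. the backward pass of the 2x2/stride-2 max pool above routes each
// incoming gradient to its window's argmax ("select" picks the winner,
// "scatter" accumulates collisions):
//
//   XlaComputation ge = CreateScalarGeComputation(F32, builder);
//   XlaComputation add = CreateScalarAddComputation(F32, builder);
//   XlaOp grad_in = SelectAndScatter(forward_input, /*select=*/ge,
//                                    /*window_dimensions=*/{1, 1, 2, 2},
//                                    /*window_strides=*/{1, 1, 2, 2},
//                                    Padding::kValid, /*source=*/grad_out,
//                                    ConstantR0<float>(builder, 0.0f),
//                                    /*scatter=*/add);
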
XlaOp Abs(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Exp(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
}
XlaOp Expm1(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
}
XlaOp Floor(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
}
XlaOp Ceil(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
}
XlaOp Round(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
XlaOp Log(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
}
XlaOp Log1p(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
}
XlaOp Logistic(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand);
}
XlaOp Sign(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
}
XlaOp Clz(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
}
XlaOp Cos(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
}
XlaOp Sin(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
}
XlaOp Tanh(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
}
XlaOp Real(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
}
XlaOp Imag(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
}
XlaOp Sqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
}
XlaOp Cbrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand);
}
XlaOp Rsqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
}

XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp IsFinite(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}

XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->ConvertElementType(operand, new_element_type);
}

XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->BitcastConvertType(operand, new_element_type);
}

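// Illustrative sketch: ConvertElementType is a value cast, while
// BitcastConvertType reinterprets bits between same-width types:
//
//   XlaOp as_float = ConvertElementType(int_op, F32);    // 1 -> 1.0f
//   XlaOp raw_bits = BitcastConvertType(float_op, U32);  // 1.0f -> 0x3F800000
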
XlaOp Neg(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
}

XlaOp Transpose(const XlaOp operand, absl::Span<const int64> permutation) {
  return operand.builder()->Transpose(operand, permutation);
}

XlaOp Rev(const XlaOp operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Rev(operand, dimensions);
}

XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension, bool is_stable) {
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}

XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
  return min.builder()->Clamp(min, operand, max);
}

XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions, static_operands);
}

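// Illustrative sketch (keys/values are placeholders): sorting a key array
// and carrying a value array along with it. The comparator helper lives in
// client/lib/comparators.h and takes the scalar element type of each
// operand; multi-operand Sort returns a tuple:
//
//   XlaComputation less = CreateScalarLtComputation({F32, S32}, builder);
//   XlaOp sorted = Sort({keys, values}, less, /*dimension=*/0,
//                       /*is_stable=*/true);
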
XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  return initial_state.builder()->RngBitGenerator(algorithm, initial_state,
                                                  shape);
}

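// Illustrative sketch: 100 draws from U[0, 1) and from N(0, 1):
//
//   Shape s = ShapeUtil::MakeShape(F32, {100});
//   XlaOp zero = ConstantR0<float>(builder, 0.0f);
//   XlaOp one = ConstantR0<float>(builder, 1.0f);
//   XlaOp u = RngUniform(zero, one, s);  // a = low, b = high
//   XlaOp n = RngNormal(zero, one, s);   // mu, sigma
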
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  return init.builder()->While(condition, body, init);
}

XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}

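// Illustrative sketch: a loop that counts from 0 to 10. Condition and body
// are built in subsidiary builders over the loop-carried value:
//
//   XlaBuilder cond_b("cond");
//   Lt(Parameter(&cond_b, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//      ConstantR0<int32>(&cond_b, 10));
//   XlaBuilder body_b("body");
//   Add(Parameter(&body_b, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//       ConstantR0<int32>(&body_b, 1));
//   XlaOp ten = While(cond_b.Build().ValueOrDie(),
//                     body_b.Build().ValueOrDie(),
//                     ConstantR0<int32>(builder, 0));
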
XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
                      const int mantissa_bits) {
  return operand.builder()->ReducePrecision(operand, exponent_bits,
                                            mantissa_bits);
}

XlaOp Gather(const XlaOp input, const XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  return input.builder()->Gather(input, start_indices, dimension_numbers,
                                 slice_sizes, indices_are_sorted);
}

XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
              const XlaOp updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return input.builder()->Scatter(input, scatter_indices, updates,
                                  update_computation, dimension_numbers,
                                  indices_are_sorted, unique_indices);
}

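// Illustrative sketch (operand/indices/D are placeholders): gathering whole
// rows of an [N, D] operand with an S32[K] vector of row indices, i.e. a
// "take" along dimension 0 producing a [K, D] result:
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 spans each row
//   dnums.add_collapsed_slice_dims(0);  // drop the size-1 sliced row dim
//   dnums.add_start_index_map(0);       // indices address operand dim 0
//   dnums.set_index_vector_dim(1);      // indices are scalars
//   XlaOp rows = Gather(operand, indices, dnums, /*slice_sizes=*/{1, D});
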
void Send(const XlaOp operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}

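// Illustrative sketch: tokens impose ordering on side-effecting ops. A fresh
// token seeds the chain, each *WithToken op returns a tuple whose last
// element is the next token, and AfterAll joins several chains:
//
//   XlaOp token = CreateToken(builder);
//   XlaOp infeed = InfeedWithToken(token, shape, /*config=*/"");
//   XlaOp data = GetTupleElement(infeed, 0);
//   XlaOp next_token = GetTupleElement(infeed, 1);
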
XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64 feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64 feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64 feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
                                          grad_output, epsilon, feature_index);
}

XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}

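// Illustrative sketch: Iota(builder, S32, 5) yields the S32[5] vector
// [0, 1, 2, 3, 4]; the shaped overload fills one dimension of a larger
// array, e.g. the row index at every element of an [M, N] matrix:
//
//   XlaOp rows = Iota(builder, ShapeUtil::MakeShape(S32, {M, N}),
//                     /*iota_dimension=*/0);
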
XlaOp GetDimensionSize(const XlaOp operand, int64 dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
  return operand.builder()->SetDimensionSize(operand, val, dimension);
}

XlaOp RemoveDynamicDimension(const XlaOp operand, int64 dimension) {
  return operand.builder()->RemoveDynamicDimension(operand, dimension);
}

}  // namespace xla