1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/xla_builder.h"
17 
18 #include <functional>
19 #include <numeric>
20 #include <queue>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/match.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_join.h"
31 #include "tensorflow/compiler/xla/client/sharding_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/execution_options_util.h"
34 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/util.h"
40 
41 namespace xla {
42 
43 using absl::StrCat;
44 
45 namespace {
46 
47 static const char kNameSeparator = '.';
48 
49 // Retrieves the base name of an instruction's or computation's fully qualified
50 // name, using the separator as the boundary between the initial base name part
51 // and the numeric identifier.
52 string GetBaseName(const string& name, char separator) {
53   auto pos = name.rfind(separator);
54   CHECK_NE(pos, string::npos) << name;
55   return name.substr(0, pos);
56 }
57 
58 // Generates a fully qualified computation/instruction name.
59 string GetFullName(const string& base_name, char separator, int64 id) {
60   const char separator_str[] = {separator, '\0'};
61   return StrCat(base_name, separator_str, id);
62 }
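// Illustration (not part of the original source): with kNameSeparator '.',
// GetFullName("add", '.', 7) produces "add.7", and GetBaseName("add.7", '.')
// recovers "add"; the trailing number is the unique id assigned to the
// instruction or computation.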
63 
64 // Common function to standardize setting name and IDs on computation and
65 // instruction proto entities.
66 template <typename T>
67 void SetProtoIdAndName(T* entry, const string& base_name, char separator,
68                        int64 id) {
69   entry->set_id(id);
70   entry->set_name(GetFullName(base_name, separator, id));
71 }
72 
73 template <typename InstructionType>
74 StatusOr<InstructionType> LookUpInstructionByHandleInternal(
75     const absl::flat_hash_map<int64, int64>& handle_to_index,
76     const std::vector<HloInstructionProto>& instructions, int64 handle) {
77   auto it = handle_to_index.find(handle);
78   if (it == handle_to_index.end()) {
79     return InvalidArgument("No XlaOp with handle %d", handle);
80   }
81   return const_cast<InstructionType>(&instructions.at(it->second));
82 }
83 
84 Status CheckBuildersAffinity(const XlaBuilder* op_builder,
85                              const XlaBuilder* builder, int64 handle) {
86   if (op_builder != builder) {
87     return InvalidArgument(
88         "XlaOp with handle %d is built by builder '%s', but is trying to use "
89         "it in builder '%s'",
90         handle, op_builder->name(), builder->name());
91   }
92   return Status::OK();
93 }
94 
95 template <typename InstructionType, typename OpBuilderType,
96           typename BuilderType, typename OpType>
97 StatusOr<InstructionType> LookUpInstructionInternal(
98     const absl::flat_hash_map<int64, int64>& handle_to_index,
99     const std::vector<HloInstructionProto>& instructions,
100     OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
101   if (op_builder == nullptr) {
102     return InvalidArgument(
103         "Invalid XlaOp with handle %d; the builder of this op is freed",
104         op_handle);
105   }
106   TF_RETURN_IF_ERROR(CheckBuildersAffinity(op_builder, builder, op_handle));
107   return LookUpInstructionByHandleInternal<InstructionType>(
108       handle_to_index, instructions, op_handle);
109 }
110 
111 }  // namespace
112 
113 XlaOp operator-(XlaOp x) { return Neg(x); }
114 XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
115 XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
116 XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
117 XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
118 XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }
119 
120 XlaOp operator~(XlaOp x) { return Not(x); }
121 XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
122 XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
123 XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
124 XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }
125 
126 XlaOp operator>>(XlaOp x, XlaOp y) {
127   XlaBuilder* builder = x.builder();
128   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
129     TF_ASSIGN_OR_RETURN(const xla::Shape* shape, builder->GetShapePtr(x));
130     if (!ShapeUtil::ElementIsIntegral(*shape)) {
131       return InvalidArgument(
132           "Argument to >> operator does not have an integral type (%s).",
133           ShapeUtil::HumanString(*shape));
134     }
135     if (ShapeUtil::ElementIsSigned(*shape)) {
136       return ShiftRightArithmetic(x, y);
137     } else {
138       return ShiftRightLogical(x, y);
139     }
140   });
141 }
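// Usage sketch (illustrative only; relies on the public client API declared in
// xla_builder.h, not on code in this file):
//
//   XlaBuilder b("example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {16}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16}), "y");
//   XlaOp z = x * y + x;  // Lowered to Mul and Add via the overloads above.
//   StatusOr<XlaComputation> computation = b.Build();
//
// The >> overload above picks ShiftRightArithmetic or ShiftRightLogical based
// on whether the element type is signed.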
142 
143 StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
144   TF_RETURN_IF_ERROR(first_error_);
145   TF_RETURN_IF_ERROR(CheckBuildersAffinity(op.builder(), this, op.handle()));
146   auto it = handle_to_index_.find(op.handle());
147   if (it == handle_to_index_.end()) {
148     return InvalidArgument("No XlaOp with handle %d", op.handle());
149   }
150   return instruction_shapes_.at(it->second).get();
151 }
152 
153 StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
154   TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
155   return *shape;
156 }
157 
158 StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
159     absl::Span<const XlaOp> operands) const {
160   std::vector<Shape> operand_shapes;
161   for (XlaOp operand : operands) {
162     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
163     operand_shapes.push_back(*shape);
164   }
165   return operand_shapes;
166 }
167 
168 XlaBuilder::XlaBuilder(const string& computation_name)
169     : name_(computation_name) {}
170 
171 XlaBuilder::~XlaBuilder() {}
172 
173 XlaOp XlaBuilder::ReportError(const Status& error) {
174   CHECK(!error.ok());
175   if (die_immediately_on_error_) {
176     LOG(FATAL) << "error building computation: " << error;
177   }
178 
179   if (first_error_.ok()) {
180     first_error_ = error;
181     first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
182   }
183   return XlaOp(this);
184 }
185 
186 XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
187   if (!first_error_.ok()) {
188     return XlaOp(this);
189   }
190   if (!op.ok()) {
191     return ReportError(op.status());
192   }
193   return op.ValueOrDie();
194 }
195 
196 XlaOp XlaBuilder::ReportErrorOrReturn(
197     const std::function<StatusOr<XlaOp>()>& op_creator) {
198   return ReportErrorOrReturn(op_creator());
199 }
200 
201 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
202   TF_RETURN_IF_ERROR(first_error_);
203   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
204                       LookUpInstructionByHandle(root_id));
205 
206   ProgramShape program_shape;
207 
208   *program_shape.mutable_result() = Shape(root_proto->shape());
209 
210   // Check that the parameter numbers are continuous from 0, and add parameter
211   // shapes and names to the program shape.
212   const int64 param_count = parameter_numbers_.size();
213   for (int64 i = 0; i < param_count; i++) {
214     program_shape.add_parameters();
215     program_shape.add_parameter_names();
216   }
217   for (const HloInstructionProto& instr : instructions_) {
218     // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
219     // to verify continuity, we just need to verify that every parameter is in
220     // the right range.
221     if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
222       const int64 index = instr.parameter_number();
223       TF_RET_CHECK(index >= 0 && index < param_count)
224           << "invalid parameter number: " << index;
225       *program_shape.mutable_parameters(index) = Shape(instr.shape());
226       *program_shape.mutable_parameter_names(index) = instr.name();
227     }
228   }
229   return program_shape;
230 }
231 
232 StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
233   TF_RET_CHECK(!instructions_.empty());
234   return GetProgramShape(instructions_.back().id());
235 }
236 
237 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
238   if (root.builder_ != this) {
239     return InvalidArgument("Given root operation is not in this computation.");
240   }
241   return GetProgramShape(root.handle());
242 }
243 
244 void XlaBuilder::IsConstantVisitor(const int64 op_handle,
245                                    absl::flat_hash_set<int64>* visited,
246                                    bool* is_constant) const {
247   if (visited->contains(op_handle) || !*is_constant) {
248     return;
249   }
250 
251   const HloInstructionProto& instr =
252       *(LookUpInstructionByHandle(op_handle).ValueOrDie());
253   const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
254   switch (opcode) {
255     default:
256       for (const int64 operand_id : instr.operand_ids()) {
257         IsConstantVisitor(operand_id, visited, is_constant);
258       }
259       // TODO(b/32495713): We aren't checking the called computations.
260       break;
261 
262     case HloOpcode::kGetDimensionSize:
263       // GetDimensionSize is always considered constant in XLA -- if a dynamic
264       // dimension is present, -1 is returned.
265       break;
266 
267     // Non-functional ops.
268     case HloOpcode::kRng:
269     case HloOpcode::kAllReduce:
270       // TODO(b/33009255): Implement constant folding for cross replica sum.
271     case HloOpcode::kInfeed:
272     case HloOpcode::kOutfeed:
273     case HloOpcode::kCall:
274       // TODO(b/32495713): We aren't checking the to_apply computation itself,
275       // so we conservatively say that computations containing the Call op
276       // cannot be constant.  We cannot set is_functional=false in other similar
277       // cases since we're already relying on IsConstant to return true.
278     case HloOpcode::kCustomCall:
279     case HloOpcode::kWhile:
280       // TODO(b/32495713): We aren't checking the condition and body
281       // computations themselves.
282     case HloOpcode::kScatter:
283       // TODO(b/32495713): We aren't checking the embedded computation in
284       // Scatter.
285     case HloOpcode::kSend:
286     case HloOpcode::kRecv:
287     case HloOpcode::kParameter:
288       *is_constant = false;
289       break;
290   }
291   if (!*is_constant) {
292     VLOG(1) << "Non-constant: " << instr.name();
293   }
294   visited->insert(op_handle);
295 }
296 
297 Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
298                                      ShapeIndex dynamic_size_param_index,
299                                      int64 target_param_num,
300                                      ShapeIndex target_param_index,
301                                      int64 target_dim_num) {
302   bool param_exists = false;
303   for (size_t index = 0; index < instructions_.size(); ++index) {
304     HloInstructionProto& instr = instructions_[index];
305     if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
306         instr.parameter_number() == target_param_num) {
307       param_exists = true;
308       Shape param_shape(instr.shape());
309       Shape* param_shape_ptr = &param_shape;
310       for (int64 index : target_param_index) {
311         param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
312       }
313       param_shape_ptr->set_dynamic_dimension(target_dim_num,
314                                              /*is_dynamic=*/true);
315       *instr.mutable_shape() = param_shape.ToProto();
316       instruction_shapes_[index] =
317           absl::make_unique<Shape>(std::move(param_shape));
318     }
319   }
320   if (!param_exists) {
321     return InvalidArgument(
322         "Asked to mark parameter %lld as a dynamic sized parameter, but the "
323         "parameter doesn't exist",
324         target_param_num);
325   }
326 
327   TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
328       DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
329                                                 dynamic_size_param_index},
330       DynamicParameterBinding::DynamicDimension{
331           target_param_num, target_param_index, target_dim_num}));
332   return Status::OK();
333 }
334 
335 Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
336                                                    std::string attribute,
337                                                    std::string value) {
338   TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
339   auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
340   (*frontend_attributes->mutable_map())[attribute] = std::move(value);
341   return Status::OK();
342 }
343 
344 XlaComputation XlaBuilder::BuildAndNoteError() {
345   DCHECK(parent_builder_ != nullptr);
346   auto build_status = Build();
347   if (!build_status.ok()) {
348     parent_builder_->ReportError(
349         AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
350     return {};
351   }
352   return build_status.ConsumeValueOrDie();
353 }
354 
355 Status XlaBuilder::GetCurrentStatus() const {
356   if (!first_error_.ok()) {
357     string backtrace;
358     first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
359     return AppendStatus(first_error_, backtrace);
360   }
361   return Status::OK();
362 }
363 
364 StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
365   TF_RETURN_IF_ERROR(GetCurrentStatus());
366   return Build(instructions_.back().id(), remove_dynamic_dimensions);
367 }
368 
369 StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
370                                            bool remove_dynamic_dimensions) {
371   if (root.builder_ != this) {
372     return InvalidArgument("Given root operation is not in this computation.");
373   }
374   return Build(root.handle(), remove_dynamic_dimensions);
375 }
376 
377 StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
378                                            bool remove_dynamic_dimensions) {
379   TF_RETURN_IF_ERROR(GetCurrentStatus());
380 
381   // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet, so
382   // remove all dynamic dimensions before building the XLA program, until the
383   // backend supports them.
384   if (remove_dynamic_dimensions) {
385     std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
386       if (shape->tuple_shapes_size() != 0) {
387         for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) {
388           remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
389         }
390       }
391       for (int64 i = 0; i < shape->dimensions_size(); ++i) {
392         shape->set_dynamic_dimension(i, false);
393       }
394     };
395     for (size_t index = 0; index < instructions_.size(); ++index) {
396       remove_dynamic_dimension(instruction_shapes_[index].get());
397       *instructions_[index].mutable_shape() =
398           instruction_shapes_[index]->ToProto();
399     }
400   }
401 
402   HloComputationProto entry;
403   SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
404   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
405   *entry.mutable_program_shape() = program_shape.ToProto();
406   entry.set_root_id(root_id);
407 
408   for (auto& instruction : instructions_) {
409     // Ensures that the instruction names are unique across the whole graph.
410     instruction.set_name(
411         GetFullName(instruction.name(), kNameSeparator, instruction.id()));
412     entry.add_instructions()->Swap(&instruction);
413   }
414 
415   XlaComputation computation(entry.id());
416   HloModuleProto* module = computation.mutable_proto();
417   module->set_name(entry.name());
418   module->set_id(entry.id());
419   module->set_entry_computation_name(entry.name());
420   module->set_entry_computation_id(entry.id());
421   *module->mutable_host_program_shape() = entry.program_shape();
422   for (auto& e : embedded_) {
423     module->add_computations()->Swap(&e.second);
424   }
425   module->add_computations()->Swap(&entry);
426   if (!input_output_aliases_.empty()) {
427     TF_RETURN_IF_ERROR(
428         PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
429   }
430   *(module->mutable_dynamic_parameter_binding()) =
431       dynamic_parameter_binding_.ToProto();
432 
433   // Clear data held by this builder.
434   this->instructions_.clear();
435   this->instruction_shapes_.clear();
436   this->handle_to_index_.clear();
437   this->embedded_.clear();
438   this->parameter_numbers_.clear();
439 
440   return std::move(computation);
441 }
442 
443 /* static */ Status XlaBuilder::PopulateInputOutputAlias(
444     HloModuleProto* module, const ProgramShape& program_shape,
445     const std::vector<InputOutputAlias>& input_output_aliases) {
446   HloInputOutputAliasConfig config(program_shape.result());
447   for (auto& alias : input_output_aliases) {
448     // The HloInputOutputAliasConfig does not do parameter validation as it only
449     // carries the result shape. Maybe it should be constructed with a
450     // ProgramShape to allow full validation. We will still get an error when
451     // trying to compile the HLO module, but it would be better to have
452     // validation at this stage.
453     if (alias.param_number >= program_shape.parameters_size()) {
454       return InvalidArgument("Invalid parameter number %ld (total %ld)",
455                              alias.param_number,
456                              program_shape.parameters_size());
457     }
458     const Shape& parameter_shape = program_shape.parameters(alias.param_number);
459     if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
460       return InvalidArgument("Invalid parameter %ld index: %s",
461                              alias.param_number,
462                              alias.param_index.ToString().c_str());
463     }
464     TF_RETURN_IF_ERROR(config.SetUpAlias(
465         alias.output_index, alias.param_number, alias.param_index,
466         HloInputOutputAliasConfig::AliasKind::kUserAlias));
467   }
468   *module->mutable_input_output_alias() = config.ToProto();
469   return Status::OK();
470 }
471 
472 StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
473     const Shape& shape, XlaOp operand,
474     absl::Span<const int64> broadcast_dimensions) {
475   TF_RETURN_IF_ERROR(first_error_);
476 
477   HloInstructionProto instr;
478   *instr.mutable_shape() = shape.ToProto();
479   for (int64 dim : broadcast_dimensions) {
480     instr.add_dimensions(dim);
481   }
482 
483   return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
484 }
485 
486 StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
487                                                  XlaOp operand) {
488   TF_RETURN_IF_ERROR(first_error_);
489 
490   TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
491 
492   CHECK(ShapeUtil::IsScalar(*operand_shape) ||
493         operand_shape->rank() == output_shape.rank());
494   Shape broadcast_shape =
495       ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());
496 
497   // Do explicit broadcast for scalar.
498   if (ShapeUtil::IsScalar(*operand_shape)) {
499     return InDimBroadcast(broadcast_shape, operand, {});
500   }
501 
502   // Do explicit broadcast for degenerate broadcast.
503   std::vector<int64> broadcast_dimensions;
504   std::vector<int64> reshaped_dimensions;
505   for (int i = 0; i < operand_shape->rank(); i++) {
506     if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
507       broadcast_dimensions.push_back(i);
508       reshaped_dimensions.push_back(operand_shape->dimensions(i));
509     } else {
510       TF_RET_CHECK(operand_shape->dimensions(i) == 1)
511           << "An explicit broadcast sequence requires the broadcasted "
512              "dimensions to be trivial; operand shape: "
513           << *operand_shape << "; output_shape: " << output_shape;
514     }
515   }
516 
517   Shape reshaped_shape =
518       ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);
519 
520   std::vector<std::pair<int64, int64>> unmodified_dims =
521       ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);
522 
523   for (auto& unmodified : unmodified_dims) {
524     if (operand_shape->is_dynamic_dimension(unmodified.first)) {
525       reshaped_shape.set_dynamic_dimension(unmodified.second, true);
526     }
527   }
528 
529   // Eliminate the size one dimensions.
530   TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
531   // Broadcast 'reshaped_operand' up to the larger size.
532   return InDimBroadcast(broadcast_shape, reshaped_operand,
533                         broadcast_dimensions);
534 }
535 
536 XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
537   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
538     HloInstructionProto instr;
539     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
540     TF_ASSIGN_OR_RETURN(
541         Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
542     *instr.mutable_shape() = shape.ToProto();
543     return AddInstruction(std::move(instr), unop, {operand});
544   });
545 }
546 
547 XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
548                            absl::Span<const int64> broadcast_dimensions,
549                            absl::optional<ComparisonDirection> direction) {
550   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
551     HloInstructionProto instr;
552     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
553     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
554     TF_ASSIGN_OR_RETURN(
555         Shape shape, ShapeInference::InferBinaryOpShape(
556                          binop, *lhs_shape, *rhs_shape, broadcast_dimensions));
557     *instr.mutable_shape() = shape.ToProto();
558     if (binop == HloOpcode::kCompare) {
559       if (!direction.has_value()) {
560         return InvalidArgument(
561             "kCompare expects a ComparisonDirection, but none provided.");
562       }
563       instr.set_comparison_direction(ComparisonDirectionToString(*direction));
564     } else if (direction.has_value()) {
565       return InvalidArgument(
566           "A comparison direction is provided for a non-compare opcode: %s.",
567           HloOpcodeString(binop));
568     }
569 
570     const int64 lhs_rank = lhs_shape->rank();
571     const int64 rhs_rank = rhs_shape->rank();
572 
573     XlaOp updated_lhs = lhs;
574     XlaOp updated_rhs = rhs;
575 
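    // Rank-mismatch illustration (hypothetical shapes): for lhs f32[2,3] and
    // rhs f32[3] with broadcast_dimensions={1}, the lower-rank rhs is expanded
    // below via InDimBroadcast into dimension 1 to f32[2,3], so the binary op
    // sees operands of equal rank.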
576     if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
577       const bool should_broadcast_lhs = lhs_rank < rhs_rank;
578       XlaOp from = should_broadcast_lhs ? lhs : rhs;
579       const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;
580 
581       std::vector<int64> to_size;
582       std::vector<bool> to_size_is_dynamic;
583       for (int i = 0; i < shape.rank(); i++) {
584         to_size.push_back(shape.dimensions(i));
585         to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
586       }
587       for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
588         int64 to_dim = broadcast_dimensions[from_dim];
589         to_size[to_dim] = from_shape.dimensions(from_dim);
590         to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
591       }
592 
593       const Shape& broadcasted_shape = ShapeUtil::MakeShape(
594           from_shape.element_type(), to_size, to_size_is_dynamic);
595       TF_ASSIGN_OR_RETURN(
596           XlaOp broadcasted_operand,
597           InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));
598 
599       updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
600       updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
601     }
602 
603     TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
604                         GetShapePtr(updated_lhs));
605     if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
606       TF_ASSIGN_OR_RETURN(updated_lhs,
607                           AddBroadcastSequence(shape, updated_lhs));
608     }
609     TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
610                         GetShapePtr(updated_rhs));
611     if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
612       TF_ASSIGN_OR_RETURN(updated_rhs,
613                           AddBroadcastSequence(shape, updated_rhs));
614     }
615 
616     return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
617   });
618 }
619 
620 XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
621   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
622     HloInstructionProto instr;
623     XlaOp updated_lhs = lhs;
624     XlaOp updated_rhs = rhs;
625     XlaOp updated_ehs = ehs;
626     // The client API supports implicit broadcast for kSelect and kClamp, but
627     // XLA does not support implicit broadcast. Make implicit broadcast explicit
628     // and update the operands.
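    // Illustration (hypothetical shapes): Select(pred, on_true, on_false) with
    // a scalar PRED pred and f32[4] branches broadcasts pred to pred[4] via
    // AddBroadcastSequence before the ternary instruction is added.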
629     if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
630       TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
631       TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
632       TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));
633 
634       absl::optional<Shape> non_scalar_shape;
635       for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
636         if (shape->IsArray() && shape->rank() != 0) {
637           if (non_scalar_shape.has_value()) {
638             // TODO(jpienaar): The case where we need to compute the broadcasted
639             // shape by considering multiple of the shapes is not implemented.
640             // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
641             TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
642                          shape->dimensions())
643                 << "Unimplemented implicit broadcast.";
644           } else {
645             non_scalar_shape = *shape;
646           }
647         }
648       }
649       if (non_scalar_shape.has_value()) {
650         if (ShapeUtil::IsScalar(*lhs_shape)) {
651           TF_ASSIGN_OR_RETURN(updated_lhs,
652                               AddBroadcastSequence(*non_scalar_shape, lhs));
653         }
654         if (ShapeUtil::IsScalar(*rhs_shape)) {
655           TF_ASSIGN_OR_RETURN(updated_rhs,
656                               AddBroadcastSequence(*non_scalar_shape, rhs));
657         }
658         if (ShapeUtil::IsScalar(*ehs_shape)) {
659           TF_ASSIGN_OR_RETURN(updated_ehs,
660                               AddBroadcastSequence(*non_scalar_shape, ehs));
661         }
662       }
663     }
664 
665     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
666     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
667     TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
668     StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
669         triop, *lhs_shape, *rhs_shape, *ehs_shape);
670     if (!status_or_shape.status().ok()) {
671       return InvalidArgument(
672           "%s Input scalar shapes may have been changed to non-scalar shapes.",
673           status_or_shape.status().error_message());
674     }
675     *instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto();
676     return AddInstruction(std::move(instr), triop,
677                           {updated_lhs, updated_rhs, updated_ehs});
678   });
679 }
680 
681 XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
682   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
683     if (literal.shape().IsArray() && literal.element_count() > 1 &&
684         literal.IsAllFirst()) {
685       Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
686       HloInstructionProto instr;
687       *instr.mutable_shape() = scalar.shape().ToProto();
688       *instr.mutable_literal() = scalar.ToProto();
689       TF_ASSIGN_OR_RETURN(
690           XlaOp scalar_op,
691           AddInstruction(std::move(instr), HloOpcode::kConstant));
692       return Broadcast(scalar_op, literal.shape().dimensions());
693     } else {
694       HloInstructionProto instr;
695       *instr.mutable_shape() = literal.shape().ToProto();
696       *instr.mutable_literal() = literal.ToProto();
697       return AddInstruction(std::move(instr), HloOpcode::kConstant);
698     }
699   });
700 }
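// Illustration (not part of the original source): a splat literal such as an
// f32[1024] array filled with 2.0f takes the first branch above and is emitted
// as a scalar kConstant followed by a Broadcast to [1024], instead of embedding
// all 1024 elements in the instruction proto.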
701 
702 XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
703   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
704     HloInstructionProto instr;
705     *instr.mutable_shape() = shape.ToProto();
706     instr.add_dimensions(iota_dimension);
707     return AddInstruction(std::move(instr), HloOpcode::kIota);
708   });
709 }
710 
711 XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
712   return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
713 }
714 
715 XlaOp XlaBuilder::Call(const XlaComputation& computation,
716                        absl::Span<const XlaOp> operands) {
717   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
718     HloInstructionProto instr;
719     std::vector<const Shape*> operand_shape_ptrs;
720     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
721     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
722                       [](const Shape& shape) { return &shape; });
723     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
724                         computation.GetProgramShape());
725     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
726                                          operand_shape_ptrs,
727                                          /*to_apply=*/called_program_shape));
728     *instr.mutable_shape() = shape.ToProto();
729 
730     AddCalledComputation(computation, &instr);
731 
732     return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
733   });
734 }
735 
736 XlaOp XlaBuilder::Parameter(
737     int64 parameter_number, const Shape& shape, const string& name,
738     const std::vector<bool>& replicated_at_leaf_buffers) {
739   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
740     HloInstructionProto instr;
741     if (!parameter_numbers_.insert(parameter_number).second) {
742       return InvalidArgument("parameter %d already registered",
743                              parameter_number);
744     }
745     instr.set_parameter_number(parameter_number);
746     instr.set_name(name);
747     *instr.mutable_shape() = shape.ToProto();
748     if (!replicated_at_leaf_buffers.empty()) {
749       auto replication = instr.mutable_parameter_replication();
750       for (bool replicated : replicated_at_leaf_buffers) {
751         replication->add_replicated_at_leaf_buffers(replicated);
752       }
753     }
754     return AddInstruction(std::move(instr), HloOpcode::kParameter);
755   });
756 }
757 
758 XlaOp XlaBuilder::Broadcast(XlaOp operand,
759                             absl::Span<const int64> broadcast_sizes) {
760   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
761     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
762     TF_ASSIGN_OR_RETURN(
763         const Shape& shape,
764         ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));
765 
766     // The client-level broadcast op just appends dimensions on the left (adds
767     // lowest numbered dimensions). The HLO broadcast instruction is more
768     // flexible and can add new dimensions anywhere. The instruction's
769     // dimensions field maps operand dimensions to dimensions in the broadcast
770     // output, so to append dimensions on the left the instruction's dimensions
771     // should just be the n highest dimension numbers of the output shape where
772     // n is the number of input dimensions.
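    // Illustration (hypothetical shapes): broadcasting an f32[3,4] operand with
    // broadcast_sizes={2} infers an f32[2,3,4] output; operand_rank is 2 and
    // shape.rank() is 3, so dimensions becomes {1, 2}, mapping operand dims 0
    // and 1 to output dims 1 and 2.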
773     const int64 operand_rank = operand_shape->rank();
774     std::vector<int64> dimensions(operand_rank);
775     for (int i = 0; i < operand_rank; ++i) {
776       dimensions[i] = i + shape.rank() - operand_rank;
777     }
778     return InDimBroadcast(shape, operand, dimensions);
779   });
780 }
781 
782 XlaOp XlaBuilder::BroadcastInDim(
783     XlaOp operand, const absl::Span<const int64> out_dim_size,
784     const absl::Span<const int64> broadcast_dimensions) {
785   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
786     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
787     // Output shape: in the case of a degenerate broadcast, out_dim_size is
788     // not necessarily the same as the dimension sizes of the output shape.
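    // Illustration (hypothetical shapes): operand f32[1,3], out_dim_size={4,3},
    // broadcast_dimensions={0,1}. The in-dim broadcast below keeps shape
    // f32[1,3]; since that differs from the f32[4,3] output, the degenerate
    // dimension 0 is then expanded by AddBroadcastSequence at the end.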
789     auto output_shape =
790         ShapeUtil::MakeShape(operand_shape->element_type(), out_dim_size);
791     if (operand_shape->rank() != broadcast_dimensions.size()) {
792       return InvalidArgument(
793           "Size of broadcast_dimensions has to match operand's rank; operand "
794           "rank: %lld, size of broadcast_dimensions %u.",
795           operand_shape->rank(), broadcast_dimensions.size());
796     }
797     for (int i = 0; i < broadcast_dimensions.size(); i++) {
798       if (broadcast_dimensions[i] < 0 ||
799           broadcast_dimensions[i] >= out_dim_size.size()) {
800         return InvalidArgument("Broadcast dimension %lld is out of bounds",
801                                broadcast_dimensions[i]);
802       }
803       output_shape.set_dynamic_dimension(
804           broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
805     }
806 
807     TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
808                            *operand_shape, output_shape, broadcast_dimensions)
809                            .status());
810     std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
811     for (int i = 0; i < broadcast_dimensions.size(); i++) {
812       in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
813     }
814     const auto& in_dim_shape =
815         ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
816     TF_ASSIGN_OR_RETURN(
817         XlaOp in_dim_broadcast,
818         InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));
819 
820     // If broadcast is not degenerate, return broadcasted result.
821     if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
822       return in_dim_broadcast;
823     }
824 
825     // Otherwise handle degenerate broadcast case.
826     return AddBroadcastSequence(output_shape, in_dim_broadcast);
827   });
828 }
829 
830 StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
831                                     int64 inferred_dimension) {
832   TF_RETURN_IF_ERROR(first_error_);
833 
834   HloInstructionProto instr;
835   *instr.mutable_shape() = shape.ToProto();
836   if (inferred_dimension != -1) {
837     instr.add_dimensions(inferred_dimension);
838   }
839   return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
840 }
841 
842 XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
843                         absl::Span<const int64> limit_indices,
844                         absl::Span<const int64> strides) {
845   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
846     HloInstructionProto instr;
847     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
848     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
849                                          *operand_shape, start_indices,
850                                          limit_indices, strides));
851     *instr.mutable_shape() = shape.ToProto();
852     for (int i = 0; i < start_indices.size(); i++) {
853       auto* slice_config = instr.add_slice_dimensions();
854       slice_config->set_start(start_indices[i]);
855       slice_config->set_limit(limit_indices[i]);
856       slice_config->set_stride(strides[i]);
857     }
858 
859     return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
860   });
861 }
862 
863 XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
864                              int64 limit_index, int64 stride, int64 dimno) {
865   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
866     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
867     std::vector<int64> starts(shape->rank(), 0);
868     std::vector<int64> limits(shape->dimensions().begin(),
869                               shape->dimensions().end());
870     std::vector<int64> strides(shape->rank(), 1);
871     starts[dimno] = start_index;
872     limits[dimno] = limit_index;
873     strides[dimno] = stride;
874     return Slice(operand, starts, limits, strides);
875   });
876 }
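// Illustration (hypothetical shapes): SliceInDim(x, /*start_index=*/2,
// /*limit_index=*/5, /*stride=*/1, /*dimno=*/0) on an f32[10,3] operand builds
// starts={2,0}, limits={5,3}, strides={1,1} and yields an f32[3,3] slice.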
877 
878 XlaOp XlaBuilder::DynamicSlice(XlaOp operand, XlaOp start_indices,
879                                absl::Span<const int64> slice_sizes) {
880   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
881     HloInstructionProto instr;
882 
883     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
884     TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
885                         GetShapePtr(start_indices));
886     TF_ASSIGN_OR_RETURN(
887         Shape shape, ShapeInference::InferDynamicSliceShape(
888                          *operand_shape, {*start_indices_shape}, slice_sizes));
889     *instr.mutable_shape() = shape.ToProto();
890 
891     for (int64 size : slice_sizes) {
892       instr.add_dynamic_slice_sizes(size);
893     }
894 
895     return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice,
896                           {operand, start_indices});
897   });
898 }
899 
900 XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
901                                absl::Span<const XlaOp> start_indices,
902                                absl::Span<const int64> slice_sizes) {
903   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
904     HloInstructionProto instr;
905 
906     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
907     std::vector<const Shape*> start_indices_shape_ptrs;
908     TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
909                         GetOperandShapes(start_indices));
910     absl::c_transform(start_indices_shapes,
911                       std::back_inserter(start_indices_shape_ptrs),
912                       [](const Shape& shape) { return &shape; });
913     TF_ASSIGN_OR_RETURN(Shape shape,
914                         ShapeInference::InferDynamicSliceShape(
915                             *operand_shape, start_indices_shapes, slice_sizes));
916     *instr.mutable_shape() = shape.ToProto();
917 
918     for (int64 size : slice_sizes) {
919       instr.add_dynamic_slice_sizes(size);
920     }
921 
922     std::vector<XlaOp> operands = {operand};
923     operands.insert(operands.end(), start_indices.begin(), start_indices.end());
924     return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
925   });
926 }
927 
928 XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
929                                      XlaOp start_indices) {
930   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
931     HloInstructionProto instr;
932 
933     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
934     TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
935     TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
936                         GetShapePtr(start_indices));
937     TF_ASSIGN_OR_RETURN(
938         Shape shape,
939         ShapeInference::InferDynamicUpdateSliceShape(
940             *operand_shape, *update_shape, {*start_indices_shape}));
941     *instr.mutable_shape() = shape.ToProto();
942 
943     return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
944                           {operand, update, start_indices});
945   });
946 }
947 
948 XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
949                                      absl::Span<const XlaOp> start_indices) {
950   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
951     HloInstructionProto instr;
952 
953     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
954     TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
955     std::vector<const Shape*> start_indices_shape_ptrs;
956     TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
957                         GetOperandShapes(start_indices));
958     absl::c_transform(start_indices_shapes,
959                       std::back_inserter(start_indices_shape_ptrs),
960                       [](const Shape& shape) { return &shape; });
961     TF_ASSIGN_OR_RETURN(
962         Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
963                          *operand_shape, *update_shape, start_indices_shapes));
964     *instr.mutable_shape() = shape.ToProto();
965 
966     std::vector<XlaOp> operands = {operand, update};
967     operands.insert(operands.end(), start_indices.begin(), start_indices.end());
968     return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
969                           operands);
970   });
971 }
972 
973 XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
974                               int64 dimension) {
975   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
976     HloInstructionProto instr;
977 
978     std::vector<const Shape*> operand_shape_ptrs;
979     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
980     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
981                       [](const Shape& shape) { return &shape; });
982     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
983                                          operand_shape_ptrs, dimension));
984     *instr.mutable_shape() = shape.ToProto();
985 
986     instr.add_dimensions(dimension);
987 
988     return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
989   });
990 }
991 
992 XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
993                       const PaddingConfig& padding_config) {
994   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
995     HloInstructionProto instr;
996 
997     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
998     TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
999                         GetShapePtr(padding_value));
1000     TF_ASSIGN_OR_RETURN(
1001         Shape shape, ShapeInference::InferPadShape(
1002                          *operand_shape, *padding_value_shape, padding_config));
1003     *instr.mutable_shape() = shape.ToProto();
1004     *instr.mutable_padding_config() = padding_config;
1005 
1006     return AddInstruction(std::move(instr), HloOpcode::kPad,
1007                           {operand, padding_value});
1008   });
1009 }
1010 
1011 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
1012                           absl::Span<const int64> new_sizes,
1013                           int64 inferred_dimension) {
1014   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1015     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1016     TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
1017                                                *operand_shape, dimensions,
1018                                                new_sizes, inferred_dimension));
1019     XlaOp transposed = IsIdentityPermutation(dimensions)
1020                            ? operand
1021                            : Transpose(operand, dimensions);
1022     return Reshape(shape, transposed, inferred_dimension);
1023   });
1024 }
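// Illustration (hypothetical shapes): Reshape(x, /*dimensions=*/{1,0},
// /*new_sizes=*/{6}) on an f32[2,3] operand first transposes to f32[3,2]
// (because {1,0} is not the identity permutation) and then reshapes the result
// to f32[6].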
1025 
1026 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
1027                           int64 inferred_dimension) {
1028   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1029     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
1030     std::vector<int64> dimensions(shape->dimensions_size());
1031     std::iota(dimensions.begin(), dimensions.end(), 0);
1032     return Reshape(operand, dimensions, new_sizes, inferred_dimension);
1033   });
1034 }
1035 
1036 XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
1037   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1038     if (dimensions.size() <= 1) {
1039       // Not collapsing anything; we can trivially return the operand rather
1040       // than enqueueing a trivial reshape.
1041       return operand;
1042     }
1043 
1044     // Out-of-order collapse is not supported.
1045     // Checks that the collapsed dimensions are in order and consecutive.
1046     for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
1047       if (dimensions[i] - 1 != dimensions[i - 1]) {
1048         return InvalidArgument(
1049             "Collapsed dimensions are not in consecutive order.");
1050       }
1051     }
1052 
1053     // Create a new sizes vector from the old shape, replacing the collapsed
1054     // dimensions by the product of their sizes.
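    // Illustration (hypothetical shapes): collapsing dimensions {1,2} of an
    // f32[2,3,4,5] operand folds them into a single dimension of size 3*4=12,
    // so new_sizes becomes {2,12,5} and the result is produced by the Reshape
    // below.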
1055     TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));
1056 
1057     VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
1058     VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
1059 
1060     std::vector<int64> new_sizes;
1061     for (int i = 0; i < original_shape->rank(); ++i) {
1062       if (i <= dimensions.front() || i > dimensions.back()) {
1063         new_sizes.push_back(original_shape->dimensions(i));
1064       } else {
1065         new_sizes.back() *= original_shape->dimensions(i);
1066       }
1067     }
1068 
1069     VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";
1070 
1071     return Reshape(operand, new_sizes);
1072   });
1073 }
1074 
1075 void XlaBuilder::Trace(const string& tag, XlaOp operand) {
1076   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1077     HloInstructionProto instr;
1078     *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
1079     *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
1080     return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
1081   });
1082 }
1083 
1084 XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
1085   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1086     TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
1087     TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
1088     TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
1089     HloOpcode opcode =
1090         true_shape->IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
1091     return TernaryOp(opcode, pred, on_true, on_false);
1092   });
1093 }
1094 
1095 XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
1096   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1097     HloInstructionProto instr;
1098     std::vector<const Shape*> operand_shape_ptrs;
1099     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
1100     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1101                       [](const Shape& shape) { return &shape; });
1102     TF_ASSIGN_OR_RETURN(const Shape shape,
1103                         ShapeInference::InferVariadicOpShape(
1104                             HloOpcode::kTuple, operand_shape_ptrs));
1105     *instr.mutable_shape() = shape.ToProto();
1106     return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
1107   });
1108 }
1109 
1110 XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
1111   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1112     HloInstructionProto instr;
1113     TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
1114     if (!tuple_shape->IsTuple()) {
1115       return InvalidArgument(
1116           "Operand to GetTupleElement() is not a tuple; got %s",
1117           ShapeUtil::HumanString(*tuple_shape));
1118     }
1119     if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
1120       return InvalidArgument(
1121           "GetTupleElement() index (%d) out of range for tuple shape %s", index,
1122           ShapeUtil::HumanString(*tuple_shape));
1123     }
1124     *instr.mutable_shape() =
1125         ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto();
1126 
1127     instr.set_tuple_index(index);
1128 
1129     return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
1130                           {tuple_data});
1131   });
1132 }
1133 
1134 XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
1135                       const PrecisionConfig* precision_config) {
1136   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1137     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1138 
1139     DotDimensionNumbers dimension_numbers;
1140     dimension_numbers.add_lhs_contracting_dimensions(
1141         lhs_shape->dimensions_size() == 1 ? 0 : 1);
1142     dimension_numbers.add_rhs_contracting_dimensions(0);
1143     return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
1144   });
1145 }
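// Illustration (hypothetical shapes): for lhs f32[m,k] and rhs f32[k,n], Dot
// contracts lhs dimension 1 with rhs dimension 0, yielding f32[m,n]; for a
// rank-1 lhs f32[k], dimension 0 is contracted instead, as set up above.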
1146 
1147 XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs,
1148                              const DotDimensionNumbers& dimension_numbers,
1149                              const PrecisionConfig* precision_config) {
1150   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1151     HloInstructionProto instr;
1152     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1153     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1154     TF_ASSIGN_OR_RETURN(Shape shape,
1155                         ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape,
1156                                                         dimension_numbers));
1157     *instr.mutable_shape() = shape.ToProto();
1158     *instr.mutable_dot_dimension_numbers() = dimension_numbers;
1159     if (precision_config != nullptr) {
1160       *instr.mutable_precision_config() = *precision_config;
1161     }
1162     return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
1163   });
1164 }
1165 
1166 Status XlaBuilder::VerifyConvolution(
1167     const Shape& lhs_shape, const Shape& rhs_shape,
1168     const ConvolutionDimensionNumbers& dimension_numbers) const {
1169   if (lhs_shape.rank() != rhs_shape.rank()) {
1170     return InvalidArgument(
1171         "Convolution arguments must have same number of "
1172         "dimensions. Got: %s and %s",
1173         ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1174   }
1175   int num_dims = lhs_shape.rank();
1176   if (num_dims < 2) {
1177     return InvalidArgument(
1178         "Convolution expects argument arrays with >= 3 dimensions. "
1179         "Got: %s and %s",
1180         ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1181   }
1182   int num_spatial_dims = num_dims - 2;
1183 
1184   const auto check_spatial_dimensions =
1185       [&](const char* const field_name,
1186           const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
1187               numbers) {
1188         if (numbers.size() != num_spatial_dims) {
1189           return InvalidArgument("Expected %d elements for %s, but got %d.",
1190                                  num_spatial_dims, field_name, numbers.size());
1191         }
1192         for (int i = 0; i < numbers.size(); ++i) {
1193           if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
1194             return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
1195                                    field_name, i, numbers.Get(i));
1196           }
1197         }
1198         return Status::OK();
1199       };
1200   TF_RETURN_IF_ERROR(
1201       check_spatial_dimensions("input_spatial_dimensions",
1202                                dimension_numbers.input_spatial_dimensions()));
1203   TF_RETURN_IF_ERROR(
1204       check_spatial_dimensions("kernel_spatial_dimensions",
1205                                dimension_numbers.kernel_spatial_dimensions()));
1206   return check_spatial_dimensions(
1207       "output_spatial_dimensions",
1208       dimension_numbers.output_spatial_dimensions());
1209 }
1210 
1211 XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
1212                        absl::Span<const int64> window_strides, Padding padding,
1213                        int64 feature_group_count, int64 batch_group_count,
1214                        const PrecisionConfig* precision_config) {
1215   return ConvWithGeneralDimensions(
1216       lhs, rhs, window_strides, padding,
1217       CreateDefaultConvDimensionNumbers(window_strides.size()),
1218       feature_group_count, batch_group_count, precision_config);
1219 }
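// Usage sketch for Conv above (a sketch, assuming the free-function wrappers
// in xla_builder.h and the documented defaults of
// CreateDefaultConvDimensionNumbers, i.e. the operand laid out as
// [batch, feature, spatial...] and the kernel as
// [output_feature, input_feature, spatial...]):
//
//   XlaBuilder b("conv_example");
//   auto img = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 3, 32, 32}), "img");
//   auto ker = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 3, 5, 5}), "ker");
//   Conv(img, ker, /*window_strides=*/{1, 1}, Padding::kSame);
//   // result shape: f32[1, 8, 32, 32]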
1220 
1221 XlaOp XlaBuilder::ConvWithGeneralPadding(
1222     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1223     absl::Span<const std::pair<int64, int64>> padding,
1224     int64 feature_group_count, int64 batch_group_count,
1225     const PrecisionConfig* precision_config) {
1226   return ConvGeneral(lhs, rhs, window_strides, padding,
1227                      CreateDefaultConvDimensionNumbers(window_strides.size()),
1228                      feature_group_count, batch_group_count, precision_config);
1229 }
1230 
1231 XlaOp XlaBuilder::ConvWithGeneralDimensions(
1232     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1233     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1234     int64 feature_group_count, int64 batch_group_count,
1235     const PrecisionConfig* precision_config) {
1236   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1237     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1238     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1239 
1240     TF_RETURN_IF_ERROR(
1241         VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1242 
1243     std::vector<int64> base_area_dimensions(
1244         dimension_numbers.input_spatial_dimensions_size());
1245     for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
1246          ++i) {
1247       base_area_dimensions[i] =
1248           lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
1249     }
1250 
1251     std::vector<int64> window_dimensions(
1252         dimension_numbers.kernel_spatial_dimensions_size());
1253     for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
1254          ++i) {
1255       window_dimensions[i] =
1256           rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1257     }
1258 
1259     return ConvGeneral(lhs, rhs, window_strides,
1260                        MakePadding(base_area_dimensions, window_dimensions,
1261                                    window_strides, padding),
1262                        dimension_numbers, feature_group_count,
1263                        batch_group_count, precision_config);
1264   });
1265 }
1266 
1267 XlaOp XlaBuilder::ConvGeneral(
1268     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1269     absl::Span<const std::pair<int64, int64>> padding,
1270     const ConvolutionDimensionNumbers& dimension_numbers,
1271     int64 feature_group_count, int64 batch_group_count,
1272     const PrecisionConfig* precision_config) {
1273   return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
1274                             dimension_numbers, feature_group_count,
1275                             batch_group_count, precision_config);
1276 }
1277 
1278 XlaOp XlaBuilder::ConvGeneralDilated(
1279     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1280     absl::Span<const std::pair<int64, int64>> padding,
1281     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1282     const ConvolutionDimensionNumbers& dimension_numbers,
1283     int64 feature_group_count, int64 batch_group_count,
1284     const PrecisionConfig* precision_config) {
1285   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1286     HloInstructionProto instr;
1287     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1288     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1289     TF_RETURN_IF_ERROR(
1290         VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1291 
1292     std::vector<int64> window_dimensions(
1293         dimension_numbers.kernel_spatial_dimensions_size());
1294     for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
1295          ++i) {
1296       window_dimensions[i] =
1297           rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1298     }
1299     TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
1300                         ShapeInference::InferWindowFromDimensions(
1301                             window_dimensions, window_strides, padding,
1302                             lhs_dilation, rhs_dilation));
1303 
1304     TF_ASSIGN_OR_RETURN(
1305         Shape shape, ShapeInference::InferConvolveShape(
1306                          *lhs_shape, *rhs_shape, feature_group_count,
1307                          batch_group_count, instr.window(), dimension_numbers));
1308     *instr.mutable_shape() = shape.ToProto();
1309 
1310     *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
1311     instr.set_feature_group_count(feature_group_count);
1312     instr.set_batch_group_count(batch_group_count);
1313 
1314     if (precision_config != nullptr) {
1315       *instr.mutable_precision_config() = *precision_config;
1316     }
1317 
1318     return AddInstruction(std::move(instr), HloOpcode::kConvolution,
1319                           {lhs, rhs});
1320   });
1321 }
1322 
1323 XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
1324                       const absl::Span<const int64> fft_length) {
1325   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1326     HloInstructionProto instr;
1327     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1328     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
1329                                          *operand_shape, fft_type, fft_length));
1330     *instr.mutable_shape() = shape.ToProto();
1331     instr.set_fft_type(fft_type);
1332     for (int64 i : fft_length) {
1333       instr.add_fft_length(i);
1334     }
1335 
1336     return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
1337   });
1338 }
1339 
1340 XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
1341   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1342     HloInstructionProto instr;
1343     if (!LayoutUtil::HasLayout(shape)) {
1344       return InvalidArgument("Given shape to Infeed must have a layout");
1345     }
1346     const Shape infeed_instruction_shape =
1347         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1348     *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1349     instr.set_infeed_config(config);
1350 
1351     if (shape.IsArray() && sharding() &&
1352         sharding()->type() == OpSharding::OTHER) {
1353       // TODO(b/110793772): Support tiled array-shaped infeeds.
1354       return InvalidArgument(
1355           "Tiled sharding is not yet supported for array-shaped infeeds");
1356     }
1357 
1358     if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1359       return InvalidArgument(
1360           "Replicated sharding is not yet supported for infeeds");
1361     }
1362 
1363     // Infeed takes a single token operand. Generate the token to pass to the
1364     // infeed.
1365     XlaOp token;
1366     auto make_token = [&]() {
1367       HloInstructionProto token_instr;
1368       *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1369       return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
1370     };
1371     if (sharding()) {
1372       // Arbitrarily assign token to device 0.
1373       OpSharding sharding = sharding_builder::AssignDevice(0);
1374       XlaScopedShardingAssignment scoped_sharding(this, sharding);
1375       TF_ASSIGN_OR_RETURN(token, make_token());
1376     } else {
1377       TF_ASSIGN_OR_RETURN(token, make_token());
1378     }
1379 
1380     // The sharding is set by the client according to the data tuple shape.
1381     // However, the shape of the infeed instruction is a tuple containing the
1382     // data and a token. For tuple sharding type, the sharding must be changed
1383     // to accommodate the token.
1384     XlaOp infeed;
1385     if (sharding() && sharding()->type() == OpSharding::TUPLE) {
1386       // TODO(b/80000000): Remove this when clients have been updated to handle
1387       // tokens.
1388       OpSharding infeed_instruction_sharding = *sharding();
1389       // Arbitrarily assign the token to device 0.
1390       *infeed_instruction_sharding.add_tuple_shardings() =
1391           sharding_builder::AssignDevice(0);
1392       XlaScopedShardingAssignment scoped_sharding(this,
1393                                                   infeed_instruction_sharding);
1394       TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1395                                                  HloOpcode::kInfeed, {token}));
1396     } else {
1397       TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1398                                                  HloOpcode::kInfeed, {token}));
1399     }
1400 
1401     // The infeed instruction produces a tuple of the infeed data and a
1402     // token. Return the XLA op containing the data.
1403     // TODO(b/80000000): Remove this when clients have been updated to handle
1404     // tokens.
1405     HloInstructionProto infeed_data;
1406     *infeed_data.mutable_shape() = shape.ToProto();
1407     infeed_data.set_tuple_index(0);
1408     return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
1409                           {infeed});
1410   });
1411 }
1412 
1413 XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
1414                                   const string& config) {
1415   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1416     HloInstructionProto instr;
1417     if (!LayoutUtil::HasLayout(shape)) {
1418       return InvalidArgument("Given shape to Infeed must have a layout");
1419     }
1420     const Shape infeed_instruction_shape =
1421         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1422     *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1423     instr.set_infeed_config(config);
1424 
1425     if (shape.IsArray() && sharding() &&
1426         sharding()->type() == OpSharding::OTHER) {
1427       // TODO(b/110793772): Support tiled array-shaped infeeds.
1428       return InvalidArgument(
1429           "Tiled sharding is not yet supported for array-shaped infeeds");
1430     }
1431 
1432     if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1433       return InvalidArgument(
1434           "Replicated sharding is not yet supported for infeeds");
1435     }
1436 
1437     return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
1438   });
1439 }
1440 
1441 void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
1442                          const string& outfeed_config) {
1443   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1444     HloInstructionProto instr;
1445 
1446     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1447 
1448     // Check and set outfeed shape.
1449     if (!LayoutUtil::HasLayout(shape_with_layout)) {
1450       return InvalidArgument("Given shape to Outfeed must have a layout");
1451     }
1452     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1453     if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1454       return InvalidArgument(
1455           "Outfeed shape %s must be compatible with operand shape %s",
1456           ShapeUtil::HumanStringWithLayout(shape_with_layout),
1457           ShapeUtil::HumanStringWithLayout(*operand_shape));
1458     }
1459     *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1460 
1461     instr.set_outfeed_config(outfeed_config);
1462 
1463     // Outfeed takes a token as its second operand. Generate the token to pass
1464     // to the outfeed.
1465     HloInstructionProto token_instr;
1466     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1467     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
1468                                                     HloOpcode::kAfterAll, {}));
1469 
1470     TF_RETURN_IF_ERROR(
1471         AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
1472             .status());
1473 
1474     // The outfeed instruction produces a token. However, existing users expect
1475     // a nil shape (empty tuple). This should only be relevant if the outfeed is
1476     // the root of a computation.
1477     // TODO(b/80000000): Remove this when clients have been updated to handle
1478     // tokens.
1479     HloInstructionProto tuple_instr;
1480     *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
1481 
1482     // The dummy tuple should have no sharding.
1483     {
1484       XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
1485       TF_ASSIGN_OR_RETURN(
1486           XlaOp empty_tuple,
1487           AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
1488       return empty_tuple;
1489     }
1490   });
1491 }
1492 
1493 XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
1494                                    const Shape& shape_with_layout,
1495                                    const string& outfeed_config) {
1496   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1497     HloInstructionProto instr;
1498 
1499     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1500 
1501     // Check and set outfeed shape.
1502     if (!LayoutUtil::HasLayout(shape_with_layout)) {
1503       return InvalidArgument("Given shape to Outfeed must have a layout");
1504     }
1505     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1506     if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1507       return InvalidArgument(
1508           "Outfeed shape %s must be compatible with operand shape %s",
1509           ShapeUtil::HumanStringWithLayout(shape_with_layout),
1510           ShapeUtil::HumanStringWithLayout(*operand_shape));
1511     }
1512     *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1513 
1514     instr.set_outfeed_config(outfeed_config);
1515 
1516     return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
1517                           {operand, token});
1518   });
1519 }
1520 
1521 XlaOp XlaBuilder::CreateToken() {
1522   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1523     HloInstructionProto instr;
1524     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1525     return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
1526   });
1527 }
1528 
1529 XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
1530   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1531     if (tokens.empty()) {
1532       return InvalidArgument("AfterAll requires at least one operand");
1533     }
1534     for (int i = 0; i < tokens.size(); ++i) {
1535       XlaOp operand = tokens[i];
1536       TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1537       if (!operand_shape->IsToken()) {
1538         return InvalidArgument(
1539             "All operands to AfterAll must be tokens; operand %d has shape %s",
1540             i, ShapeUtil::HumanString(*operand_shape));
1541       }
1542     }
1543     HloInstructionProto instr;
1544     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1545     return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
1546   });
1547 }
1548 
1549 XlaOp XlaBuilder::CustomCall(
1550     const string& call_target_name, absl::Span<const XlaOp> operands,
1551     const Shape& shape, const string& opaque,
1552     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
1553   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1554     HloInstructionProto instr;
1555     if (absl::StartsWith(call_target_name, "$")) {
1556       return InvalidArgument(
1557           "Invalid custom_call_target \"%s\": Call targets that start with '$' "
1558           "are reserved for internal use.",
1559           call_target_name);
1560     }
1561     *instr.mutable_shape() = shape.ToProto();
1562     instr.set_custom_call_target(call_target_name);
1563     instr.set_backend_config(opaque);
1564     if (operand_shapes_with_layout.has_value()) {
1565       if (!LayoutUtil::HasLayout(shape)) {
1566         return InvalidArgument(
1567             "Result shape must have layout for custom call with constrained "
1568             "layout.");
1569       }
1570       if (operands.size() != operand_shapes_with_layout->size()) {
1571         return InvalidArgument(
1572             "Must specify a shape with layout for each operand for custom call "
1573             "with constrained layout; given %d shapes, expected %d",
1574             operand_shapes_with_layout->size(), operands.size());
1575       }
1576       instr.set_constrain_layout(true);
1577       int64 operand_num = 0;
1578       for (const Shape& operand_shape : *operand_shapes_with_layout) {
1579         if (!LayoutUtil::HasLayout(operand_shape)) {
1580           return InvalidArgument(
1581               "No layout specified for operand %d for custom call with "
1582               "constrained layout.",
1583               operand_num);
1584         }
1585         *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
1586         ++operand_num;
1587       }
1588     }
1589     return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
1590   });
1591 }
1592 
1593 XlaOp XlaBuilder::Transpose(XlaOp operand,
1594                             absl::Span<const int64> permutation) {
1595   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1596     HloInstructionProto instr;
1597     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1598     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
1599                                          *operand_shape, permutation));
1600     *instr.mutable_shape() = shape.ToProto();
1601     for (int64 dim : permutation) {
1602       instr.add_dimensions(dim);
1603     }
1604     return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
1605   });
1606 }
1607 
1608 XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
1609   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1610     HloInstructionProto instr;
1611     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1612     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
1613                                          *operand_shape, dimensions));
1614     *instr.mutable_shape() = shape.ToProto();
1615     for (int64 dim : dimensions) {
1616       instr.add_dimensions(dim);
1617     }
1618     return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
1619   });
1620 }
1621 
1622 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
1623                        const XlaComputation& comparator, int64 dimension,
1624                        bool is_stable) {
1625   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1626     HloInstructionProto instr;
1627     instr.set_is_stable(is_stable);
1628     std::vector<const Shape*> operand_shape_ptrs;
1629     TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
1630                         GetOperandShapes(operands));
1631     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1632                       [](const Shape& shape) { return &shape; });
1633     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
1634                                          HloOpcode::kSort, operand_shape_ptrs));
1635     *instr.mutable_shape() = shape.ToProto();
1636     if (dimension == -1) {
1637       TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
1638       dimension = keys_shape->rank() - 1;
1639     }
1640     instr.add_dimensions(dimension);
1641     AddCalledComputation(comparator, &instr);
1642     return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
1643   });
1644 }
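// Sketch of building a comparator for Sort (assuming the free-function
// wrappers in xla_builder.h): the comparator receives two scalar parameters
// per operand, the pair of elements being compared, and returns a PRED
// scalar; Lt gives an ascending order. With dimension == -1 the last
// dimension of the first operand is sorted.
//
//   XlaBuilder cmp_b("lt");
//   auto a0 = Parameter(&cmp_b, 0, ShapeUtil::MakeShape(F32, {}), "a");
//   auto a1 = Parameter(&cmp_b, 1, ShapeUtil::MakeShape(F32, {}), "b");
//   Lt(a0, a1);
//   XlaComputation lt = cmp_b.Build().ValueOrDie();
//
//   XlaBuilder b("sort_example");
//   auto vals = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {16, 128}), "vals");
//   Sort({vals}, lt);  // sorts each length-128 row independently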
1645 
1646 XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
1647                                      PrimitiveType new_element_type) {
1648   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1649     HloInstructionProto instr;
1650     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1651     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
1652                                          *operand_shape, new_element_type));
1653     *instr.mutable_shape() = shape.ToProto();
1654     return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
1655   });
1656 }
1657 
1658 XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
1659                                      PrimitiveType new_element_type) {
1660   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1661     HloInstructionProto instr;
1662     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1663     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
1664                                          *operand_shape, new_element_type));
1665     *instr.mutable_shape() = shape.ToProto();
1666     return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
1667                           {operand});
1668   });
1669 }
1670 
1671 XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
1672   return TernaryOp(HloOpcode::kClamp, min, operand, max);
1673 }
1674 
1675 XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
1676                       const XlaComputation& computation,
1677                       absl::Span<const int64> dimensions,
1678                       absl::Span<const XlaOp> static_operands) {
1679   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1680     if (!static_operands.empty()) {
1681       return Unimplemented("static_operands is not supported in Map");
1682     }
1683 
1684     HloInstructionProto instr;
1685     std::vector<const Shape*> operand_shape_ptrs;
1686     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
1687     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1688                       [](const Shape& shape) { return &shape; });
1689     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
1690                         computation.GetProgramShape());
1691     TF_ASSIGN_OR_RETURN(
1692         Shape shape, ShapeInference::InferMapShape(
1693                          operand_shape_ptrs, called_program_shape, dimensions));
1694     *instr.mutable_shape() = shape.ToProto();
1695 
1696     Shape output_shape(instr.shape());
1697     const int64 output_rank = output_shape.rank();
1698     AddCalledComputation(computation, &instr);
1699     std::vector<XlaOp> new_operands(operands.begin(), operands.end());
1700     for (XlaOp& new_operand : new_operands) {
1701       TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
1702       const int64 rank = shape->rank();
1703       if (rank != output_rank) {
1704         TF_ASSIGN_OR_RETURN(new_operand,
1705                             InDimBroadcast(output_shape, new_operand, {}));
1706         TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
1707       }
1708       if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
1709         TF_ASSIGN_OR_RETURN(new_operand,
1710                             AddBroadcastSequence(output_shape, new_operand));
1711       }
1712     }
1713 
1714     return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
1715   });
1716 }
1717 
1718 XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
1719                         absl::Span<const XlaOp> parameters,
1720                         const Shape& shape) {
1721   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1722     HloInstructionProto instr;
1723 
1724     // Check the number of parameters per RNG distribution.
1725     switch (distribution) {
1726       case RandomDistribution::RNG_NORMAL:
1727       case RandomDistribution::RNG_UNIFORM:
1728         if (parameters.size() != 2) {
1729           return InvalidArgument(
1730               "RNG distribution (%s) expects 2 parameters, but got %ld",
1731               RandomDistribution_Name(distribution), parameters.size());
1732         }
1733         break;
1734       default:
1735         LOG(FATAL) << "unhandled distribution " << distribution;
1736     }
1737 
1738     TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
1739     *instr.mutable_shape() = shape.ToProto();
1740 
1741     instr.set_distribution(distribution);
1742 
1743     return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
1744   });
1745 }
1746 
1747 XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
1748   return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
1749 }
1750 
1751 XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
1752   return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
1753 }
1754 
1755 XlaOp XlaBuilder::While(const XlaComputation& condition,
1756                         const XlaComputation& body, XlaOp init) {
1757   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1758     HloInstructionProto instr;
1759 
1760     // Infer shape.
1761     TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
1762     TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
1763                         condition.GetProgramShape());
1764     TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
1765     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
1766                                          condition_program_shape,
1767                                          body_program_shape, *init_shape));
1768     *instr.mutable_shape() = shape.ToProto();
1769     // Body comes before condition computation in the vector.
1770     AddCalledComputation(body, &instr);
1771     AddCalledComputation(condition, &instr);
1772     return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
1773   });
1774 }
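// Sketch of a counting loop built with While (assuming the free-function
// wrappers in xla_builder.h): `condition` maps the loop state to a PRED
// scalar and `body` maps the state to a new state of the same shape.
//
//   const Shape s32 = ShapeUtil::MakeShape(S32, {});
//   XlaBuilder cond_b("cond");
//   Lt(Parameter(&cond_b, 0, s32, "i"), ConstantR0<int32>(&cond_b, 10));
//   XlaBuilder body_b("body");
//   Add(Parameter(&body_b, 0, s32, "i"), ConstantR0<int32>(&body_b, 1));
//   XlaBuilder b("while_example");
//   While(cond_b.Build().ValueOrDie(), body_b.Build().ValueOrDie(),
//         ConstantR0<int32>(&b, 0));  // iterates until i == 10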
1775 
1776 XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
1777                          const GatherDimensionNumbers& dimension_numbers,
1778                          absl::Span<const int64> slice_sizes,
1779                          bool indices_are_sorted) {
1780   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1781     HloInstructionProto instr;
1782     instr.set_indices_are_sorted(indices_are_sorted);
1783 
1784     TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
1785     TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
1786                         GetShapePtr(start_indices));
1787     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
1788                                          *input_shape, *start_indices_shape,
1789                                          dimension_numbers, slice_sizes));
1790     *instr.mutable_shape() = shape.ToProto();
1791 
1792     *instr.mutable_gather_dimension_numbers() = dimension_numbers;
1793     for (int64 bound : slice_sizes) {
1794       instr.add_gather_slice_sizes(bound);
1795     }
1796 
1797     return AddInstruction(std::move(instr), HloOpcode::kGather,
1798                           {input, start_indices});
1799   });
1800 }
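// Sketch of a row gather (illustrative names; assumes the free-function
// wrappers in xla_builder.h): given `input` of shape f32[3, 4] and `indices`
// of shape s32[2] holding row numbers, the result has shape f32[2, 4].
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // result dim 1 comes from the slice
//   dnums.add_collapsed_slice_dims(0);  // drop the size-1 row dimension
//   dnums.add_start_index_map(0);       // indices select along operand dim 0
//   dnums.set_index_vector_dim(1);      // each index is a scalar
//   Gather(input, indices, dnums, /*slice_sizes=*/{1, 4});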
1801 
1802 XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1803                           const XlaComputation& update_computation,
1804                           const ScatterDimensionNumbers& dimension_numbers,
1805                           bool indices_are_sorted, bool unique_indices) {
1806   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1807     HloInstructionProto instr;
1808     instr.set_indices_are_sorted(indices_are_sorted);
1809 
1810     instr.set_unique_indices(unique_indices);
1811 
1812     TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
1813     TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
1814                         GetShapePtr(scatter_indices));
1815     TF_ASSIGN_OR_RETURN(const Shape* updates_shape, GetShapePtr(updates));
1816     TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
1817                         update_computation.GetProgramShape());
1818     TF_ASSIGN_OR_RETURN(
1819         Shape shape, ShapeInference::InferScatterShape(
1820                          *input_shape, *scatter_indices_shape, *updates_shape,
1821                          to_apply_shape, dimension_numbers));
1822     *instr.mutable_shape() = shape.ToProto();
1823 
1824     *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
1825 
1826     AddCalledComputation(update_computation, &instr);
1827     return AddInstruction(std::move(instr), HloOpcode::kScatter,
1828                           {input, scatter_indices, updates});
1829   });
1830 }
1831 
1832 XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
1833                               const XlaComputation& true_computation,
1834                               XlaOp false_operand,
1835                               const XlaComputation& false_computation) {
1836   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1837     TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(predicate));
1838 
1839     if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
1840       return InvalidArgument(
1841           "Argument to predicated-Conditional is not a scalar of PRED type "
1842           "(%s).",
1843           ShapeUtil::HumanString(*shape));
1844     }
1845     // The index of true_computation must be 0 and that of false computation
1846     // must be 1.
1847     return ConditionalImpl(predicate, {&true_computation, &false_computation},
1848                            {true_operand, false_operand});
1849   });
1850 }
1851 
1852 XlaOp XlaBuilder::Conditional(
1853     XlaOp branch_index,
1854     absl::Span<const XlaComputation* const> branch_computations,
1855     absl::Span<const XlaOp> branch_operands) {
1856   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1857     TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index));
1858 
1859     if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
1860       return InvalidArgument(
1861           "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
1862           ShapeUtil::HumanString(*shape));
1863     }
1864     return ConditionalImpl(branch_index, branch_computations, branch_operands);
1865   });
1866 }
1867 
1868 XlaOp XlaBuilder::ConditionalImpl(
1869     XlaOp branch_index,
1870     absl::Span<const XlaComputation* const> branch_computations,
1871     absl::Span<const XlaOp> branch_operands) {
1872   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1873     HloInstructionProto instr;
1874 
1875     TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
1876                         GetShapePtr(branch_index));
1877     std::vector<Shape> branch_operand_shapes(branch_operands.size());
1878     std::vector<ProgramShape> branch_computation_shapes(
1879         branch_computations.size());
1880     for (int j = 0; j < branch_operands.size(); ++j) {
1881       TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
1882                           GetShape(branch_operands[j]));
1883       TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
1884                           branch_computations[j]->GetProgramShape());
1885     }
1886     TF_ASSIGN_OR_RETURN(const Shape shape,
1887                         ShapeInference::InferConditionalShape(
1888                             *branch_index_shape, branch_computation_shapes,
1889                             branch_operand_shapes));
1890     *instr.mutable_shape() = shape.ToProto();
1891 
1892     for (const XlaComputation* branch_computation : branch_computations) {
1893       AddCalledComputation(*branch_computation, &instr);
1894     }
1895 
1896     std::vector<XlaOp> operands(1, branch_index);
1897     for (const XlaOp branch_operand : branch_operands) {
1898       operands.emplace_back(branch_operand);
1899     }
1900     return AddInstruction(std::move(instr), HloOpcode::kConditional,
1901                           absl::MakeSpan(operands));
1902   });
1903 }
1904 
1905 XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
1906                          const XlaComputation& computation,
1907                          absl::Span<const int64> dimensions_to_reduce) {
1908   return Reduce(absl::Span<const XlaOp>({operand}),
1909                 absl::Span<const XlaOp>({init_value}), computation,
1910                 dimensions_to_reduce);
1911 }
1912 
1913 XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
1914                          absl::Span<const XlaOp> init_values,
1915                          const XlaComputation& computation,
1916                          absl::Span<const int64> dimensions_to_reduce) {
1917   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1918     HloInstructionProto instr;
1919 
1920     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
1921                         computation.GetProgramShape());
1922 
1923     std::vector<XlaOp> all_operands;
1924     all_operands.insert(all_operands.end(), operands.begin(), operands.end());
1925     all_operands.insert(all_operands.end(), init_values.begin(),
1926                         init_values.end());
1927 
1928     std::vector<const Shape*> operand_shape_ptrs;
1929     TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
1930                         GetOperandShapes(all_operands));
1931     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1932                       [](const Shape& shape) { return &shape; });
1933 
1934     TF_ASSIGN_OR_RETURN(
1935         Shape shape,
1936         ShapeInference::InferReduceShape(
1937             operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
1938     *instr.mutable_shape() = shape.ToProto();
1939 
1940     for (int64 dim : dimensions_to_reduce) {
1941       instr.add_dimensions(dim);
1942     }
1943 
1944     AddCalledComputation(computation, &instr);
1945 
1946     return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
1947   });
1948 }
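// Usage sketch for Reduce (assuming the free-function wrappers in
// xla_builder.h): reducing an f32[2, 3] operand over dimension 0 with a
// scalar add computation yields f32[3]; with several operand/init pairs the
// result is a tuple of reduced arrays.
//
//   XlaBuilder add_b("add");
//   auto s = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(&add_b, 0, s, "x"), Parameter(&add_b, 1, s, "y"));
//   XlaComputation add = add_b.Build().ValueOrDie();
//   XlaBuilder b("reduce_example");
//   auto in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "in");
//   Reduce(in, ConstantR0<float>(&b, 0.0f), add, /*dimensions_to_reduce=*/{0});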
1949 
1950 XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
1951                             const XlaComputation& computation) {
1952   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1953     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1954     std::vector<int64> all_dimnos(operand_shape->rank());
1955     std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
1956     return Reduce(operand, init_value, computation, all_dimnos);
1957   });
1958 }
1959 
1960 XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
1961                                const XlaComputation& computation,
1962                                absl::Span<const int64> window_dimensions,
1963                                absl::Span<const int64> window_strides,
1964                                Padding padding) {
1965   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1966     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1967     TF_RETURN_IF_ERROR(
1968         ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
1969                               window_dimensions, window_strides));
1970 
1971     std::vector<std::pair<int64, int64>> padding_values =
1972         MakePadding(AsInt64Slice(operand_shape->dimensions()),
1973                     window_dimensions, window_strides, padding);
1974     return ReduceWindowWithGeneralPadding(
1975         operand, init_value, computation, window_dimensions, window_strides,
1976         /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
1977   });
1978 }
1979 
1980 XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
1981     XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1982     absl::Span<const int64> window_dimensions,
1983     absl::Span<const int64> window_strides,
1984     absl::Span<const int64> base_dilations,
1985     absl::Span<const int64> window_dilations,
1986     absl::Span<const std::pair<int64, int64>> padding) {
1987   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1988     HloInstructionProto instr;
1989 
1990     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1991     TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
1992     TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
1993                         computation.GetProgramShape());
1994     TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
1995                         ShapeInference::InferWindowFromDimensions(
1996                             window_dimensions, window_strides, padding,
1997                             /*lhs_dilation=*/base_dilations,
1998                             /*rhs_dilation=*/window_dilations));
1999     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
2000                                          *operand_shape, *init_shape,
2001                                          instr.window(), to_apply_shape));
2002     *instr.mutable_shape() = shape.ToProto();
2003 
2004     AddCalledComputation(computation, &instr);
2005     return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
2006                           {operand, init_value});
2007   });
2008 }
2009 
2010 XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
2011                                     float epsilon, int64 feature_index) {
2012   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2013     HloInstructionProto instr;
2014 
2015     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2016     TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2017     TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2018     TF_ASSIGN_OR_RETURN(
2019         Shape shape,
2020         ShapeInference::InferBatchNormTrainingShape(
2021             *operand_shape, *scale_shape, *offset_shape, feature_index));
2022     *instr.mutable_shape() = shape.ToProto();
2023 
2024     instr.set_epsilon(epsilon);
2025     instr.set_feature_index(feature_index);
2026 
2027     return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
2028                           {operand, scale, offset});
2029   });
2030 }
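// Sketch of the result layout of BatchNormTraining (illustrative names;
// assumes the free-function wrappers in xla_builder.h): the result is a
// three-element tuple of the normalized output, the per-feature batch mean,
// and the per-feature batch variance. For `act` of shape f32[N, C, H, W]
// with `scale` and `offset` of shape f32[C]:
//
//   auto bn = BatchNormTraining(act, scale, offset, /*epsilon=*/1e-5f,
//                               /*feature_index=*/1);
//   auto normalized = GetTupleElement(bn, 0);  // f32[N, C, H, W]
//   auto batch_mean = GetTupleElement(bn, 1);  // f32[C]
//   auto batch_var = GetTupleElement(bn, 2);   // f32[C]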
2031 
2032 XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
2033                                      XlaOp mean, XlaOp variance, float epsilon,
2034                                      int64 feature_index) {
2035   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2036     HloInstructionProto instr;
2037 
2038     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2039     TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2040     TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2041     TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
2042     TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
2043     TF_ASSIGN_OR_RETURN(Shape shape,
2044                         ShapeInference::InferBatchNormInferenceShape(
2045                             *operand_shape, *scale_shape, *offset_shape,
2046                             *mean_shape, *variance_shape, feature_index));
2047     *instr.mutable_shape() = shape.ToProto();
2048 
2049     instr.set_epsilon(epsilon);
2050     instr.set_feature_index(feature_index);
2051 
2052     return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
2053                           {operand, scale, offset, mean, variance});
2054   });
2055 }
2056 
2057 XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
2058                                 XlaOp batch_var, XlaOp grad_output,
2059                                 float epsilon, int64 feature_index) {
2060   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2061     HloInstructionProto instr;
2062 
2063     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2064     TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2065     TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape, GetShapePtr(batch_mean));
2066     TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
2067     TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
2068                         GetShapePtr(grad_output));
2069     TF_ASSIGN_OR_RETURN(
2070         Shape shape, ShapeInference::InferBatchNormGradShape(
2071                          *operand_shape, *scale_shape, *batch_mean_shape,
2072                          *batch_var_shape, *grad_output_shape, feature_index));
2073     *instr.mutable_shape() = shape.ToProto();
2074 
2075     instr.set_epsilon(epsilon);
2076     instr.set_feature_index(feature_index);
2077 
2078     return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
2079                           {operand, scale, batch_mean, batch_var, grad_output});
2080   });
2081 }
2082 
2083 XlaOp XlaBuilder::CrossReplicaSum(
2084     XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
2085   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2086     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
2087     const Shape* element_shape;
2088     if (shape->IsTuple()) {
2089       if (shape->tuple_shapes_size() == 0) {
2090         return Unimplemented(
2091             "0 element tuple CrossReplicaSum is not supported");
2092       }
2093       element_shape = &shape->tuple_shapes(0);
2094     } else {
2095       element_shape = shape;
2096     }
2097     const Shape scalar_shape =
2098         ShapeUtil::MakeShape(element_shape->element_type(), {});
2099     auto b = CreateSubBuilder("sum");
2100     auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
2101     auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
2102     if (scalar_shape.element_type() == PRED) {
2103       Or(x, y);
2104     } else {
2105       Add(x, y);
2106     }
2107     TF_ASSIGN_OR_RETURN(auto computation, b->Build());
2108     return AllReduce(operand, computation, replica_groups,
2109                      /*channel_id=*/absl::nullopt);
2110   });
2111 }
2112 
2113 XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
2114                             absl::Span<const ReplicaGroup> replica_groups,
2115                             const absl::optional<ChannelHandle>& channel_id,
2116                             const absl::optional<Shape>& shape_with_layout) {
2117   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2118     HloInstructionProto instr;
2119     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2120     std::vector<const Shape*> operand_shapes;
2121     std::vector<XlaOp> operands;
2122     if (operand_shape->IsTuple()) {
2123       if (operand_shape->tuple_shapes_size() == 0) {
2124         return Unimplemented("0 element tuple AllReduce is not supported");
2125       }
2126       for (int64 i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
2127         if (operand_shape->tuple_shapes(i).element_type() !=
2128             operand_shape->tuple_shapes(0).element_type()) {
2129           return Unimplemented(
2130               "All the shapes of a tuple input of AllReduce must have the same "
2131               "element type");
2132         }
2133         operand_shapes.push_back(&operand_shape->tuple_shapes(i));
2134         operands.push_back(GetTupleElement(operand, i));
2135       }
2136     } else {
2137       operand_shapes.push_back(operand_shape);
2138       operands.push_back(operand);
2139     }
2140 
2141     TF_ASSIGN_OR_RETURN(Shape inferred_shape,
2142                         ShapeInference::InferAllReduceShape(operand_shapes));
2143     if (shape_with_layout) {
2144       if (!LayoutUtil::HasLayout(*shape_with_layout)) {
2145         return InvalidArgument("shape_with_layout must have the layout set: %s",
2146                                shape_with_layout->ToString());
2147       }
2148       if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
2149         return InvalidArgument(
2150             "Provided shape_with_layout must be compatible with the "
2151             "operand shape: %s vs %s",
2152             shape_with_layout->ToString(), operand_shape->ToString());
2153       }
2154       instr.set_constrain_layout(true);
2155       if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
2156         // For a single-element tuple, take the tuple element shape.
2157         TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
2158         *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
2159       } else {
2160         *instr.mutable_shape() = shape_with_layout->ToProto();
2161       }
2162     } else {
2163       *instr.mutable_shape() = inferred_shape.ToProto();
2164     }
2165 
2166     for (const ReplicaGroup& group : replica_groups) {
2167       *instr.add_replica_groups() = group;
2168     }
2169 
2170     if (channel_id.has_value()) {
2171       instr.set_channel_id(channel_id->handle());
2172     }
2173 
2174     AddCalledComputation(computation, &instr);
2175 
2176     TF_ASSIGN_OR_RETURN(
2177         auto all_reduce,
2178         AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
2179     if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
2180       // For a single-element tuple, wrap the result into a tuple.
2181       TF_RET_CHECK(operand_shapes.size() == 1);
2182       TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
2183       return Tuple({all_reduce});
2184     }
2185     return all_reduce;
2186   });
2187 }
2188 
2189 XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
2190                            int64 concat_dimension, int64 split_count,
2191                            const std::vector<ReplicaGroup>& replica_groups) {
2192   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2193     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2194 
2195     // The HloInstruction for Alltoall currently only handles the data
2196     // communication: it accepts N already split parts and scatters them to N
2197     // cores, and each core gathers the N received parts into a tuple as the
2198     // output. So here we explicitly split the operand before the hlo alltoall,
2199     // and concat the tuple elements.
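    // For example: with split_count=2, an f32[4, 8] operand,
    // split_dimension=0 and concat_dimension=1, the operand is sliced into
    // two f32[2, 8] parts, the kAllToAll exchanges one part with each
    // participant, and each core concatenates the two f32[2, 8] parts it
    // receives along dimension 1 into an f32[2, 16] result.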
2200     //
2201     // First, run shape inference to make sure the shapes are valid.
2202     TF_RETURN_IF_ERROR(
2203         ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
2204                                            concat_dimension, split_count)
2205             .status());
2206 
2207     // Split into N parts.
2208     std::vector<XlaOp> slices;
2209     slices.reserve(split_count);
2210     const int64 block_size =
2211         operand_shape->dimensions(split_dimension) / split_count;
2212     for (int i = 0; i < split_count; i++) {
2213       slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
2214                                   /*limit_index=*/(i + 1) * block_size,
2215                                   /*stride=*/1, /*dimno=*/split_dimension));
2216     }
2217 
2218     // Handle data communication.
2219     HloInstructionProto instr;
2220     TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
2221     std::vector<const Shape*> slice_shape_ptrs;
2222     absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
2223                       [](const Shape& shape) { return &shape; });
2224     TF_ASSIGN_OR_RETURN(
2225         Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
2226     *instr.mutable_shape() = shape.ToProto();
2227     for (const ReplicaGroup& group : replica_groups) {
2228       *instr.add_replica_groups() = group;
2229     }
2230     TF_ASSIGN_OR_RETURN(
2231         XlaOp alltoall,
2232         AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));
2233 
2234     // Concat the N received parts.
2235     std::vector<XlaOp> received;
2236     received.reserve(split_count);
2237     for (int i = 0; i < split_count; i++) {
2238       received.push_back(this->GetTupleElement(alltoall, i));
2239     }
2240     return this->ConcatInDim(received, concat_dimension);
2241   });
2242 }
2243 
2244 XlaOp XlaBuilder::CollectivePermute(
2245     XlaOp operand,
2246     const std::vector<std::pair<int64, int64>>& source_target_pairs) {
2247   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2248     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2249     HloInstructionProto instr;
2250     TF_ASSIGN_OR_RETURN(
2251         Shape shape,
2252         ShapeInference::InferCollectivePermuteShape(*operand_shape));
2253     *instr.mutable_shape() = shape.ToProto();
2254 
2255     for (const auto& pair : source_target_pairs) {
2256       auto* proto_pair = instr.add_source_target_pairs();
2257       proto_pair->set_source(pair.first);
2258       proto_pair->set_target(pair.second);
2259     }
2260 
2261     return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
2262                           {operand});
2263   });
2264 }
2265 
2266 XlaOp XlaBuilder::ReplicaId() {
2267   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2268     HloInstructionProto instr;
2269     *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
2270     return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
2271   });
2272 }
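
// Usage sketch (illustrative only; `b` is an assumed builder): ReplicaId
// yields a scalar U32 identifying the executing replica, e.g. for selecting a
// per-replica shard of an input:
//
//   XlaOp rid = b.ReplicaId();  // u32[] in [0, num_replicas)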
2273 
2274 XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
2275                                    absl::Span<const int64> window_dimensions,
2276                                    absl::Span<const int64> window_strides,
2277                                    Padding padding, XlaOp source,
2278                                    XlaOp init_value,
2279                                    const XlaComputation& scatter) {
2280   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2281     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2282     return SelectAndScatterWithGeneralPadding(
2283         operand, select, window_dimensions, window_strides,
2284         MakePadding(AsInt64Slice(operand_shape->dimensions()),
2285                     window_dimensions, window_strides, padding),
2286         source, init_value, scatter);
2287   });
2288 }
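
// Usage sketch (illustrative only; every name below is assumed):
// SelectAndScatter is the building block behind gradients of windowed
// reductions such as max-pooling. With `ge_computation` comparing two scalars
// with >= and `add_computation` adding them, a 2x2, stride-2 pooling gradient
// over NHWC activations might look like:
//
//   XlaOp result = b.SelectAndScatter(activations, ge_computation,
//                                     /*window_dimensions=*/{1, 2, 2, 1},
//                                     /*window_strides=*/{1, 2, 2, 1},
//                                     Padding::kValid, pooled_gradients,
//                                     zero, add_computation);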
2289 
2290 XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
2291     XlaOp operand, const XlaComputation& select,
2292     absl::Span<const int64> window_dimensions,
2293     absl::Span<const int64> window_strides,
2294     absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
2295     XlaOp init_value, const XlaComputation& scatter) {
2296   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2297     HloInstructionProto instr;
2298 
2299     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2300     TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
2301     TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
2302     TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
2303                         select.GetProgramShape());
2304     TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
2305                         scatter.GetProgramShape());
2306     TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
2307                         ShapeInference::InferWindowFromDimensions(
2308                             window_dimensions, window_strides, padding,
2309                             /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
2310     TF_ASSIGN_OR_RETURN(Shape shape,
2311                         ShapeInference::InferSelectAndScatterShape(
2312                             *operand_shape, select_shape, instr.window(),
2313                             *source_shape, *init_shape, scatter_shape));
2314     *instr.mutable_shape() = shape.ToProto();
2315 
2316     AddCalledComputation(select, &instr);
2317     AddCalledComputation(scatter, &instr);
2318 
2319     return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
2320                           {operand, source, init_value});
2321   });
2322 }
2323 
2324 XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
2325                                   const int mantissa_bits) {
2326   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2327     HloInstructionProto instr;
2328     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2329     TF_ASSIGN_OR_RETURN(Shape shape,
2330                         ShapeInference::InferReducePrecisionShape(
2331                             *operand_shape, exponent_bits, mantissa_bits));
2332     *instr.mutable_shape() = shape.ToProto();
2333     instr.set_exponent_bits(exponent_bits);
2334     instr.set_mantissa_bits(mantissa_bits);
2335     return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
2336                           {operand});
2337   });
2338 }
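
// Usage sketch (illustrative only; `b` and `x` are assumed): ReducePrecision
// rounds an F32 value to a lower-precision format while keeping the F32 type.
// With 8 exponent bits and 7 mantissa bits it mimics bfloat16 rounding:
//
//   XlaOp rounded = b.ReducePrecision(x, /*exponent_bits=*/8,
//                                     /*mantissa_bits=*/7);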
2339 
2340 void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
2341   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2342     // Send HLO takes two operands: a data operand and a token. Generate the
2343     // token to pass into the send.
2344     // TODO(b/80000000): Remove this when clients have been updated to handle
2345     // tokens.
2346     HloInstructionProto token_instr;
2347     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2348     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
2349                                                     HloOpcode::kAfterAll, {}));
2350 
2351     return SendWithToken(operand, token, handle);
2352   });
2353 }
2354 
2355 XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
2356                                 const ChannelHandle& handle) {
2357   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2358     if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
2359       return InvalidArgument("Send must use a device-to-device channel");
2360     }
2361 
2362     // Send instruction produces a tuple of {aliased operand, U32 context,
2363     // token}.
2364     HloInstructionProto send_instr;
2365     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
2366     *send_instr.mutable_shape() =
2367         ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
2368                                    ShapeUtil::MakeTokenShape()})
2369             .ToProto();
2370     send_instr.set_channel_id(handle.handle());
2371     TF_ASSIGN_OR_RETURN(XlaOp send,
2372                         AddInstruction(std::move(send_instr), HloOpcode::kSend,
2373                                        {operand, token}));
2374 
2375     HloInstructionProto send_done_instr;
2376     *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2377     send_done_instr.set_channel_id(handle.handle());
2378     return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
2379                           {send});
2380   });
2381 }
2382 
2383 XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
2384   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2385     // Recv HLO takes a single token operand. Generate the token to pass into
2386     // the Recv and RecvDone instructions.
2387     // TODO(b/80000000): Remove this when clients have been updated to handle
2388     // tokens.
2389     HloInstructionProto token_instr;
2390     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2391     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
2392                                                     HloOpcode::kAfterAll, {}));
2393 
2394     XlaOp recv = RecvWithToken(token, shape, handle);
2395 
2396     // The RecvDone instruction produces a tuple of the data and a token.
2397     // Return the XlaOp containing the data.
2398     // TODO(b/80000000): Remove this when clients have been updated to handle
2399     // tokens.
2400     HloInstructionProto recv_data;
2401     *recv_data.mutable_shape() = shape.ToProto();
2402     recv_data.set_tuple_index(0);
2403     return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
2404                           {recv});
2405   });
2406 }
2407 
2408 XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
2409                                 const ChannelHandle& handle) {
2410   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2411     if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
2412       return InvalidArgument("Recv must use a device-to-device channel");
2413     }
2414 
2415     // Recv instruction produces a tuple of {receive buffer, U32 context,
2416     // token}.
2417     HloInstructionProto recv_instr;
2418     *recv_instr.mutable_shape() =
2419         ShapeUtil::MakeTupleShape(
2420             {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
2421             .ToProto();
2422     recv_instr.set_channel_id(handle.handle());
2423     TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
2424                                                    HloOpcode::kRecv, {token}));
2425 
2426     HloInstructionProto recv_done_instr;
2427     *recv_done_instr.mutable_shape() =
2428         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
2429             .ToProto();
2430     recv_done_instr.set_channel_id(handle.handle());
2431     return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
2432                           {recv});
2433   });
2434 }
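
// Usage sketch (illustrative only; `sender`, `receiver`, `payload`,
// `payload_shape`, and the DEVICE_TO_DEVICE `channel` are assumed, as is the
// CreateToken helper declared alongside this builder API): the token-threaded
// forms order side-effecting transfers explicitly.
//
//   // In the sending computation:
//   XlaOp send_token = CreateToken(&sender);
//   XlaOp send_done = sender.SendWithToken(payload, send_token, channel);
//
//   // In the receiving computation:
//   XlaOp recv_token = CreateToken(&receiver);
//   XlaOp recv = receiver.RecvWithToken(recv_token, payload_shape, channel);
//   XlaOp data = GetTupleElement(recv, 0);  // element 1 is the result token.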
2435 
2436 XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
2437                              const Shape& shape_with_layout,
2438                              const ChannelHandle& handle) {
2439   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2440     if (!LayoutUtil::HasLayout(shape_with_layout)) {
2441       return InvalidArgument("Shape passed to SendToHost must have a layout");
2442     }
2443     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2444     if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
2445       return InvalidArgument(
2446           "SendToHost shape %s must be compatible with operand shape %s",
2447           ShapeUtil::HumanStringWithLayout(shape_with_layout),
2448           ShapeUtil::HumanStringWithLayout(*operand_shape));
2449     }
2450     // TODO(b/111544877): Support tuple shapes.
2451     if (!operand_shape->IsArray()) {
2452       return InvalidArgument("SendToHost only supports array shapes, shape: %s",
2453                              ShapeUtil::HumanString(*operand_shape));
2454     }
2455 
2456     if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
2457       return InvalidArgument("SendToHost must use a device-to-host channel");
2458     }
2459 
2460     // Send instruction produces a tuple of {aliased operand, U32 context,
2461     // token}.
2462     HloInstructionProto send_instr;
2463     *send_instr.mutable_shape() =
2464         ShapeUtil::MakeTupleShape({shape_with_layout,
2465                                    ShapeUtil::MakeShape(U32, {}),
2466                                    ShapeUtil::MakeTokenShape()})
2467             .ToProto();
2468     send_instr.set_channel_id(handle.handle());
2469     send_instr.set_is_host_transfer(true);
2470     TF_ASSIGN_OR_RETURN(XlaOp send,
2471                         AddInstruction(std::move(send_instr), HloOpcode::kSend,
2472                                        {operand, token}));
2473 
2474     HloInstructionProto send_done_instr;
2475     *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2476     send_done_instr.set_channel_id(handle.handle());
2477     send_done_instr.set_is_host_transfer(true);
2478     return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
2479                           {send});
2480   });
2481 }
2482 
2483 XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
2484                                const ChannelHandle& handle) {
2485   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2486     if (!LayoutUtil::HasLayout(shape)) {
2487       return InvalidArgument("Shape passed to RecvFromHost must have a layout");
2488     }
2489 
2490     // TODO(b/111544877): Support tuple shapes.
2491     if (!shape.IsArray()) {
2492       return InvalidArgument(
2493           "RecvFromHost only supports array shapes, shape: %s",
2494           ShapeUtil::HumanString(shape));
2495     }
2496 
2497     if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
2498       return InvalidArgument("RecvFromHost must use a host-to-device channel");
2499     }
2500 
2501     // Recv instruction produces a tuple of {receive buffer, U32 context,
2502     // token}.
2503     HloInstructionProto recv_instr;
2504     *recv_instr.mutable_shape() =
2505         ShapeUtil::MakeTupleShape(
2506             {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
2507             .ToProto();
2508     recv_instr.set_channel_id(handle.handle());
2509     recv_instr.set_is_host_transfer(true);
2510     TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
2511                                                    HloOpcode::kRecv, {token}));
2512 
2513     HloInstructionProto recv_done_instr;
2514     *recv_done_instr.mutable_shape() =
2515         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
2516             .ToProto();
2517     recv_done_instr.set_channel_id(handle.handle());
2518     recv_done_instr.set_is_host_transfer(true);
2519     return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
2520                           {recv});
2521   });
2522 }
2523 
2524 XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
2525   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2526     HloInstructionProto instr;
2527     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2528     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
2529                                          *operand_shape, dimension));
2530     *instr.mutable_shape() = shape.ToProto();
2531     instr.add_dimensions(dimension);
2532     return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
2533                           {operand});
2534   });
2535 }
2536 
2537 XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
2538   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2539     HloInstructionProto instr;
2540     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2541     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
2542                                          *operand_shape, dimension));
2543     *instr.mutable_shape() = shape.ToProto();
2544     instr.add_dimensions(dimension);
2545     return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
2546                           {operand, val});
2547   });
2548 }
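
// Usage sketch (illustrative only; `b` and `x` are assumed): GetDimensionSize
// reads the (possibly dynamic) size of a dimension as a scalar, and
// SetDimensionSize marks a dimension as dynamic with that size:
//
//   XlaOp n = b.GetDimensionSize(x, /*dimension=*/0);     // scalar size
//   XlaOp y = b.SetDimensionSize(x, n, /*dimension=*/0);  // dim 0 now dynamic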
2549 
2550 StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
2551   TF_RETURN_IF_ERROR(first_error_);
2552 
2553   // Verify that the handle is valid.
2554   TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
2555 
2556   bool is_constant = true;
2557   absl::flat_hash_set<int64> visited;
2558   IsConstantVisitor(operand.handle(), &visited, &is_constant);
2559   return is_constant;
2560 }
2561 
2562 StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
2563     XlaOp root_op, bool dynamic_dimension_is_minus_one) {
2564   TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
2565   if (!is_constant) {
2566     auto op_status = LookUpInstruction(root_op);
2567     string op_string =
2568         op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
2569     return InvalidArgument(
2570         "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
2571         "  op requested for constant subgraph: %s\n\n"
2572         "This is an internal error that typically happens when the XLA user "
2573         "(e.g. TensorFlow) is attempting to determine a value that must be a "
2574         "compile-time constant (e.g. an array dimension) but it is not capable "
2575         "of being evaluated at XLA compile time.\n\n"
2576         "Please file a usability bug with the framework being used (e.g. "
2577         "TensorFlow).",
2578         op_string);
2579   }
2580 
2581   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
2582                       LookUpInstruction(root_op));
2583 
2584   HloComputationProto entry;
2585   SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
2586                     GetNextId());
2587   entry.set_root_id(root->id());
2588   ProgramShapeProto* program_shape = entry.mutable_program_shape();
2589   *program_shape->mutable_result() = root->shape();
2590 
2591   // We use std::set to keep the instruction ids in ascending order (which is
2592   // also a valid dependency order). The related ops will be added to the
2593   // subgraph in the same order.
2594   std::set<int64> related_ops;
2595   absl::flat_hash_set<int64> related_calls;  // Related computations.
2596   std::queue<int64> worklist;
2597   worklist.push(root->id());
2598   related_ops.insert(root->id());
2599   while (!worklist.empty()) {
2600     int64 handle = worklist.front();
2601     worklist.pop();
2602     TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
2603                         LookUpInstructionByHandle(handle));
2604 
2605     if (instr_proto->opcode() ==
2606         HloOpcodeString(HloOpcode::kGetDimensionSize)) {
2607       // At this point, BuildConstantSubGraph should never encounter a
2608       // GetDimensionSize with a dynamic dimension; the IsConstant check would
2609       // have failed at the beginning of this function.
2610       //
2611       // Replace GetDimensionSize with a Constant representing the static bound
2612       // of the shape.
2613       int64 dimension = instr_proto->dimensions(0);
2614       int64 operand_handle = instr_proto->operand_ids(0);
2615       TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
2616                           LookUpInstructionByHandle(operand_handle));
2617 
2618       int32 constant_dimension_size = -1;
2619       if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
2620             dynamic_dimension_is_minus_one)) {
2621         constant_dimension_size =
2622             static_cast<int32>(operand_proto->shape().dimensions(dimension));
2623       }
2624 
2625       Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
2626 
2627       HloInstructionProto const_instr;
2628       *const_instr.mutable_shape() = literal.shape().ToProto();
2629       *const_instr.mutable_literal() = literal.ToProto();
2630       *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
2631 
2632       const_instr.set_id(handle);
2633       *const_instr.mutable_name() =
2634           GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
2635       *entry.add_instructions() =
2636           const_instr;  // Add to the result constant graph.
2637     } else {
2638       for (int64 id : instr_proto->operand_ids()) {
2639         if (related_ops.insert(id).second) {
2640           worklist.push(id);
2641         }
2642       }
2643       for (int64 called_id : instr_proto->called_computation_ids()) {
2644         related_calls.insert(called_id);
2645       }
2646     }
2647   }
2648 
2649   // Add related ops to the computation.
2650   for (int64 id : related_ops) {
2651     TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
2652                         LookUpInstructionByHandle(id));
2653 
2654     if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
2655       continue;
2656     }
2657     auto* instr = entry.add_instructions();
2658 
2659     *instr = *instr_src;
2660     // Ensure that the instruction names are unique within the graph.
2661     const string& new_name =
2662         StrCat(instr->name(), ".", entry.id(), ".", instr->id());
2663     instr->set_name(new_name);
2664   }
2665 
2666   XlaComputation computation(entry.id());
2667   HloModuleProto* module = computation.mutable_proto();
2668   module->set_name(entry.name());
2669   module->set_id(entry.id());
2670   module->set_entry_computation_name(entry.name());
2671   module->set_entry_computation_id(entry.id());
2672   *module->mutable_host_program_shape() = *program_shape;
2673   for (auto& e : embedded_) {
2674     if (related_calls.find(e.second.id()) != related_calls.end()) {
2675       *module->add_computations() = e.second;
2676     }
2677   }
2678   *module->add_computations() = std::move(entry);
2679 
2680   return std::move(computation);
2681 }
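
// Usage sketch (illustrative only; `b` and `some_op` are assumed):
// BuildConstantSubGraph extracts the parameter-independent slice of the graph
// feeding the root so a client can evaluate it at compile time, e.g. through a
// ComputeConstant-style query. If `some_op` depends on a parameter's runtime
// value, the InvalidArgument error constructed above is returned instead.
//
//   StatusOr<XlaComputation> constant_graph =
//       b.BuildConstantSubGraph(some_op,
//                               /*dynamic_dimension_is_minus_one=*/false);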
2682 
2683 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
2684     const string& computation_name) {
2685   auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
2686   sub_builder->parent_builder_ = this;
2687   sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
2688   return sub_builder;
2689 }
2690 
2691 /* static */ ConvolutionDimensionNumbers
2692 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
2693   ConvolutionDimensionNumbers dimension_numbers;
2694   dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
2695   dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
2696   dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
2697   dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
2698   dimension_numbers.set_kernel_output_feature_dimension(
2699       kConvKernelOutputDimension);
2700   dimension_numbers.set_kernel_input_feature_dimension(
2701       kConvKernelInputDimension);
2702   for (int i = 0; i < num_spatial_dims; ++i) {
2703     dimension_numbers.add_input_spatial_dimensions(i + 2);
2704     dimension_numbers.add_kernel_spatial_dimensions(i + 2);
2705     dimension_numbers.add_output_spatial_dimensions(i + 2);
2706   }
2707   return dimension_numbers;
2708 }
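
// Assuming the kConv* constants keep their usual values (batch and
// kernel-output = 0, feature and kernel-input = 1, spatial dimensions starting
// at 2), for num_spatial_dims == 2 these defaults describe the NCHW / OIHW
// layout:
//
//   input:  {N = 0, C = 1, H = 2, W = 3}
//   kernel: {O = 0, I = 1, H = 2, W = 3}
//   output: {N = 0, C = 1, H = 2, W = 3}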
2709 
2710 /* static */ Status XlaBuilder::Validate(
2711     const ConvolutionDimensionNumbers& dnum) {
2712   if (dnum.input_spatial_dimensions_size() < 2) {
2713     return FailedPrecondition("input spatial dimension < 2: %d",
2714                               dnum.input_spatial_dimensions_size());
2715   }
2716   if (dnum.kernel_spatial_dimensions_size() < 2) {
2717     return FailedPrecondition("kernel spatial dimension < 2: %d",
2718                               dnum.kernel_spatial_dimensions_size());
2719   }
2720   if (dnum.output_spatial_dimensions_size() < 2) {
2721     return FailedPrecondition("output spatial dimension < 2: %d",
2722                               dnum.output_spatial_dimensions_size());
2723   }
2724 
2725   if (std::set<int64>(
2726           {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
2727            dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
2728           .size() != 4) {
2729     return FailedPrecondition(
2730         "dimension numbers for the input are not unique: (%d, %d, %d, "
2731         "%d)",
2732         dnum.input_batch_dimension(), dnum.input_feature_dimension(),
2733         dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
2734   }
2735   if (std::set<int64>({dnum.kernel_output_feature_dimension(),
2736                        dnum.kernel_input_feature_dimension(),
2737                        dnum.kernel_spatial_dimensions(0),
2738                        dnum.kernel_spatial_dimensions(1)})
2739           .size() != 4) {
2740     return FailedPrecondition(
2741         "dimension numbers for the weight are not unique: (%d, %d, %d, "
2742         "%d)",
2743         dnum.kernel_output_feature_dimension(),
2744         dnum.kernel_input_feature_dimension(),
2745         dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
2746   }
2747   if (std::set<int64>({dnum.output_batch_dimension(),
2748                        dnum.output_feature_dimension(),
2749                        dnum.output_spatial_dimensions(0),
2750                        dnum.output_spatial_dimensions(1)})
2751           .size() != 4) {
2752     return FailedPrecondition(
2753         "dimension numbers for the output are not unique: (%d, %d, %d, "
2754         "%d)",
2755         dnum.output_batch_dimension(), dnum.output_feature_dimension(),
2756         dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
2757   }
2758   return Status::OK();
2759 }
2760 
2761 StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
2762                                            HloOpcode opcode,
2763                                            absl::Span<const XlaOp> operands) {
2764   TF_RETURN_IF_ERROR(first_error_);
2765 
2766   const int64 handle = GetNextId();
2767   instr.set_id(handle);
2768   instr.set_opcode(HloOpcodeString(opcode));
2769   if (instr.name().empty()) {
2770     instr.set_name(instr.opcode());
2771   }
2772   for (const auto& operand : operands) {
2773     if (operand.builder_ == nullptr) {
2774       return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
2775     }
2776     if (operand.builder_ != this) {
2777       return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
2778                              operand.builder_->name(), this->name());
2779     }
2780     instr.add_operand_ids(operand.handle());
2781   }
2782 
2783   *instr.mutable_metadata() = metadata_;
2784   if (sharding_) {
2785     *instr.mutable_sharding() = *sharding_;
2786   }
2787   *instr.mutable_frontend_attributes() = frontend_attributes_;
2788 
2789   handle_to_index_[handle] = instructions_.size();
2790   instructions_.push_back(std::move(instr));
2791   instruction_shapes_.push_back(
2792       absl::make_unique<Shape>(instructions_.back().shape()));
2793 
2794   XlaOp op(handle, this);
2795   return op;
2796 }
2797 
2798 void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
2799                                       HloInstructionProto* instr) {
2800   absl::flat_hash_map<int64, int64> remapped_ids;
2801   std::vector<HloComputationProto> imported_computations;
2802   imported_computations.reserve(computation.proto().computations_size());
2803   // First, import the computations, remapping their IDs and capturing the
2804   // old->new mappings in remapped_ids.
2805   for (const HloComputationProto& e : computation.proto().computations()) {
2806     HloComputationProto new_computation(e);
2807     int64 computation_id = GetNextId();
2808     remapped_ids[new_computation.id()] = computation_id;
2809     SetProtoIdAndName(&new_computation,
2810                       GetBaseName(new_computation.name(), kNameSeparator),
2811                       kNameSeparator, computation_id);
2812     for (auto& instruction : *new_computation.mutable_instructions()) {
2813       int64 instruction_id = GetNextId();
2814       remapped_ids[instruction.id()] = instruction_id;
2815       SetProtoIdAndName(&instruction,
2816                         GetBaseName(instruction.name(), kNameSeparator),
2817                         kNameSeparator, instruction_id);
2818     }
2819     new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));
2820 
2821     imported_computations.push_back(std::move(new_computation));
2822   }
2823   // Once we have imported all the computations and captured all the ID
2824   // mappings, we go back and fix up the IDs in the imported computations.
2825   instr->add_called_computation_ids(
2826       remapped_ids.at(computation.proto().entry_computation_id()));
2827   for (auto& imported_computation : imported_computations) {
2828     for (auto& instruction : *imported_computation.mutable_instructions()) {
2829       for (auto& operand_id : *instruction.mutable_operand_ids()) {
2830         operand_id = remapped_ids.at(operand_id);
2831       }
2832       for (auto& control_predecessor_id :
2833            *instruction.mutable_control_predecessor_ids()) {
2834         control_predecessor_id = remapped_ids.at(control_predecessor_id);
2835       }
2836       for (auto& called_computation_id :
2837            *instruction.mutable_called_computation_ids()) {
2838         called_computation_id = remapped_ids.at(called_computation_id);
2839       }
2840     }
2841 
2842     int64 computation_id = imported_computation.id();
2843     embedded_.insert({computation_id, std::move(imported_computation)});
2844   }
2845 }
2846 
2847 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
2848     const XlaOp op) const {
2849   TF_RETURN_IF_ERROR(first_error_);
2850   return LookUpInstructionInternal<const HloInstructionProto*>(
2851       handle_to_index_, instructions_, op.builder_, this, op.handle());
2852 }
2853 
2854 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
2855     int64 handle) const {
2856   return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
2857       handle_to_index_, instructions_, handle);
2858 }
2859 
2860 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
2861     const XlaOp op) {
2862   TF_RETURN_IF_ERROR(first_error_);
2863   return LookUpInstructionInternal<HloInstructionProto*>(
2864       handle_to_index_, instructions_, op.builder_, this, op.handle());
2865 }
2866 
2867 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
2868     int64 handle) {
2869   return LookUpInstructionByHandleInternal<HloInstructionProto*>(
2870       handle_to_index_, instructions_, handle);
2871 }
2872 
2873 // Enqueues a "retrieve parameter value" instruction for a parameter that was
2874 // passed to the computation.
2875 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
2876                 const string& name) {
2877   std::vector<bool> empty_bools;
2878   return Parameter(builder, parameter_number, shape, name, empty_bools);
2879 }
2880 
2881 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
2882                 const string& name,
2883                 const std::vector<bool>& replicated_at_leaf_buffers) {
2884   return builder->Parameter(parameter_number, shape, name,
2885                             replicated_at_leaf_buffers);
2886 }
2887 
2888 // Enqueues a constant with the value of the given literal onto the
2889 // computation.
2890 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
2891   return builder->ConstantLiteral(literal);
2892 }
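
// Usage sketch (illustrative only; shapes and names are assumed): the free
// functions in this section are thin wrappers over the XlaBuilder methods, so
// a small computation can be written without naming the builder on every op:
//
//   XlaBuilder b("axpy");
//   XlaOp a = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "a");
//   XlaOp x = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1024}), "x");
//   XlaOp y = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {1024}), "y");
//   Add(Mul(Broadcast(a, {1024}), x), y);
//   StatusOr<XlaComputation> axpy = b.Build();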
2893 
2894 XlaOp Broadcast(const XlaOp operand, absl::Span<const int64> broadcast_sizes) {
2895   return operand.builder()->Broadcast(operand, broadcast_sizes);
2896 }
2897 
2898 XlaOp BroadcastInDim(const XlaOp operand,
2899                      const absl::Span<const int64> out_dim_size,
2900                      const absl::Span<const int64> broadcast_dimensions) {
2901   return operand.builder()->BroadcastInDim(operand, out_dim_size,
2902                                            broadcast_dimensions);
2903 }
2904 
2905 XlaOp Copy(const XlaOp operand) {
2906   return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
2907 }
2908 
2909 XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
2910           const PaddingConfig& padding_config) {
2911   return operand.builder()->Pad(operand, padding_value, padding_config);
2912 }
2913 
2914 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> dimensions,
2915               absl::Span<const int64> new_sizes) {
2916   return operand.builder()->Reshape(operand, dimensions, new_sizes);
2917 }
2918 
2919 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
2920   return operand.builder()->Reshape(operand, new_sizes);
2921 }
2922 
2923 XlaOp ReshapeWithInferredDimension(XlaOp operand,
2924                                    absl::Span<const int64> new_sizes,
2925                                    int64 inferred_dimension) {
2926   return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
2927 }
2928 
2929 XlaOp Collapse(const XlaOp operand, absl::Span<const int64> dimensions) {
2930   return operand.builder()->Collapse(operand, dimensions);
2931 }
2932 
2933 XlaOp Slice(const XlaOp operand, absl::Span<const int64> start_indices,
2934             absl::Span<const int64> limit_indices,
2935             absl::Span<const int64> strides) {
2936   return operand.builder()->Slice(operand, start_indices, limit_indices,
2937                                   strides);
2938 }
2939 
2940 XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index,
2941                  int64 stride, int64 dimno) {
2942   return operand.builder()->SliceInDim(operand, start_index, limit_index,
2943                                        stride, dimno);
2944 }
2945 
2946 XlaOp DynamicSlice(const XlaOp operand, const XlaOp start_indices,
2947                    absl::Span<const int64> slice_sizes) {
2948   return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
2949 }
2950 XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
2951                    absl::Span<const int64> slice_sizes) {
2952   return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
2953 }
2954 
2955 XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
2956                          const XlaOp start_indices) {
2957   return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
2958 }
2959 
2960 XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
2961                          absl::Span<const XlaOp> start_indices) {
2962   return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
2963 }
2964 
2965 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
2966                   int64 dimension) {
2967   return builder->ConcatInDim(operands, dimension);
2968 }
2969 
2970 void Trace(const string& tag, const XlaOp operand) {
2971   return operand.builder()->Trace(tag, operand);
2972 }
2973 
2974 XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
2975   return pred.builder()->Select(pred, on_true, on_false);
2976 }
2977 
2978 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
2979   return builder->Tuple(elements);
2980 }
2981 
2982 XlaOp GetTupleElement(const XlaOp tuple_data, int64 index) {
2983   return tuple_data.builder()->GetTupleElement(tuple_data, index);
2984 }
2985 
2986 XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
2987          absl::Span<const int64> broadcast_dimensions) {
2988   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
2989 }
2990 
2991 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
2992          absl::Span<const int64> broadcast_dimensions) {
2993   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
2994 }
2995 
2996 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
2997          absl::Span<const int64> broadcast_dimensions) {
2998   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
2999 }
3000 
3001 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
3002          absl::Span<const int64> broadcast_dimensions) {
3003   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
3004 }
3005 
3006 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
3007          absl::Span<const int64> broadcast_dimensions) {
3008   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
3009 }
3010 
3011 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
3012          absl::Span<const int64> broadcast_dimensions) {
3013   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
3014 }
3015 
3016 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
3017               absl::Span<const int64> broadcast_dimensions,
3018               ComparisonDirection direction) {
3019   return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
3020                                  broadcast_dimensions, direction);
3021 }
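
// The comparison helpers above are shorthand for Compare with the matching
// direction, so the following two expressions (with assumed operands `x` and
// `y`) build the same HLO:
//
//   XlaOp m1 = Lt(x, y);
//   XlaOp m2 = Compare(x, y, /*broadcast_dimensions=*/{},
//                      ComparisonDirection::kLt);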
3022 
3023 XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
3024           const PrecisionConfig* precision_config) {
3025   return lhs.builder()->Dot(lhs, rhs, precision_config);
3026 }
3027 
3028 XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
3029                  const DotDimensionNumbers& dimension_numbers,
3030                  const PrecisionConfig* precision_config) {
3031   return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
3032                                    precision_config);
3033 }
3034 
3035 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
3036            absl::Span<const int64> window_strides, Padding padding,
3037            int64 feature_group_count, int64 batch_group_count,
3038            const PrecisionConfig* precision_config) {
3039   return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
3040                              feature_group_count, batch_group_count,
3041                              precision_config);
3042 }
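
// Usage sketch (illustrative only; `input` and `kernel` are assumed): with
// the default dimension numbers (see CreateDefaultConvDimensionNumbers above),
// a same-padded 3x3 convolution of an f32[8,16,32,32] input with an
// f32[64,16,3,3] kernel is
//
//   XlaOp out = Conv(input, kernel, /*window_strides=*/{1, 1}, Padding::kSame,
//                    /*feature_group_count=*/1, /*batch_group_count=*/1,
//                    /*precision_config=*/nullptr);
//
// and produces an f32[8,64,32,32] result.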
3043 
3044 XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs,
3045                              absl::Span<const int64> window_strides,
3046                              absl::Span<const std::pair<int64, int64>> padding,
3047                              int64 feature_group_count, int64 batch_group_count,
3048                              const PrecisionConfig* precision_config) {
3049   return lhs.builder()->ConvWithGeneralPadding(
3050       lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
3051       precision_config);
3052 }
3053 
3054 XlaOp ConvWithGeneralDimensions(
3055     const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
3056     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
3057     int64 feature_group_count, int64 batch_group_count,
3058     const PrecisionConfig* precision_config) {
3059   return lhs.builder()->ConvWithGeneralDimensions(
3060       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
3061       batch_group_count, precision_config);
3062 }
3063 
3064 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
3065                   absl::Span<const int64> window_strides,
3066                   absl::Span<const std::pair<int64, int64>> padding,
3067                   const ConvolutionDimensionNumbers& dimension_numbers,
3068                   int64 feature_group_count, int64 batch_group_count,
3069                   const PrecisionConfig* precision_config) {
3070   return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
3071                                     dimension_numbers, feature_group_count,
3072                                     batch_group_count, precision_config);
3073 }
3074 
3075 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
3076                          absl::Span<const int64> window_strides,
3077                          absl::Span<const std::pair<int64, int64>> padding,
3078                          absl::Span<const int64> lhs_dilation,
3079                          absl::Span<const int64> rhs_dilation,
3080                          const ConvolutionDimensionNumbers& dimension_numbers,
3081                          int64 feature_group_count, int64 batch_group_count,
3082                          const PrecisionConfig* precision_config) {
3083   return lhs.builder()->ConvGeneralDilated(
3084       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
3085       dimension_numbers, feature_group_count, batch_group_count,
3086       precision_config);
3087 }
3088 
3089 XlaOp Fft(const XlaOp operand, FftType fft_type,
3090           absl::Span<const int64> fft_length) {
3091   return operand.builder()->Fft(operand, fft_type, fft_length);
3092 }
3093 
3094 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
3095                       bool unit_diagonal,
3096                       TriangularSolveOptions::Transpose transpose_a) {
3097   XlaBuilder* builder = a.builder();
3098   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3099     HloInstructionProto instr;
3100     TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
3101     TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
3102     xla::TriangularSolveOptions& options =
3103         *instr.mutable_triangular_solve_options();
3104     options.set_left_side(left_side);
3105     options.set_lower(lower);
3106     options.set_unit_diagonal(unit_diagonal);
3107     options.set_transpose_a(transpose_a);
3108     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
3109                                          *a_shape, *b_shape, options));
3110     *instr.mutable_shape() = shape.ToProto();
3111 
3112     return builder->AddInstruction(std::move(instr),
3113                                    HloOpcode::kTriangularSolve, {a, b});
3114   });
3115 }
3116 
3117 XlaOp Cholesky(XlaOp a, bool lower) {
3118   XlaBuilder* builder = a.builder();
3119   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3120     HloInstructionProto instr;
3121     TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
3122     xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
3123     options.set_lower(lower);
3124     TF_ASSIGN_OR_RETURN(Shape shape,
3125                         ShapeInference::InferCholeskyShape(*a_shape));
3126     *instr.mutable_shape() = shape.ToProto();
3127 
3128     return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
3129   });
3130 }
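
// Usage sketch (illustrative only; `a` holds a symmetric positive-definite
// matrix A and `rhs` the right-hand side): Cholesky and TriangularSolve can be
// combined to solve A x = rhs while reading only A's lower triangle:
//
//   XlaOp l = Cholesky(a, /*lower=*/true);  // A = L * L^T
//   XlaOp z = TriangularSolve(l, rhs, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);
//   XlaOp x = TriangularSolve(l, z, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::TRANSPOSE);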
3131 
3132 XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
3133   return builder->Infeed(shape, config);
3134 }
3135 
3136 void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
3137              const string& outfeed_config) {
3138   return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
3139 }
3140 
3141 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
3142            absl::Span<const XlaOp> operands) {
3143   return builder->Call(computation, operands);
3144 }
3145 
3146 XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
3147                  absl::Span<const XlaOp> operands, const Shape& shape,
3148                  const string& opaque) {
3149   return builder->CustomCall(call_target_name, operands, shape, opaque,
3150                              /*operand_shapes_with_layout=*/absl::nullopt);
3151 }
3152 
3153 XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
3154                            absl::Span<const XlaOp> operands, const Shape& shape,
3155                            absl::Span<const Shape> operand_shapes_with_layout,
3156                            const string& opaque) {
3157   return builder->CustomCall(call_target_name, operands, shape, opaque,
3158                              operand_shapes_with_layout);
3159 }
3160 
3161 XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
3162               absl::Span<const int64> broadcast_dimensions) {
3163   return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
3164                                  broadcast_dimensions);
3165 }
3166 
3167 XlaOp Conj(const XlaOp operand) {
3168   return Complex(Real(operand), Neg(Imag(operand)));
3169 }
3170 
3171 XlaOp Add(const XlaOp lhs, const XlaOp rhs,
3172           absl::Span<const int64> broadcast_dimensions) {
3173   return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
3174                                  broadcast_dimensions);
3175 }
3176 
3177 XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
3178           absl::Span<const int64> broadcast_dimensions) {
3179   return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
3180                                  broadcast_dimensions);
3181 }
3182 
3183 XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
3184           absl::Span<const int64> broadcast_dimensions) {
3185   return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
3186                                  broadcast_dimensions);
3187 }
3188 
3189 XlaOp Div(const XlaOp lhs, const XlaOp rhs,
3190           absl::Span<const int64> broadcast_dimensions) {
3191   return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
3192                                  broadcast_dimensions);
3193 }
3194 
3195 XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
3196           absl::Span<const int64> broadcast_dimensions) {
3197   return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
3198                                  broadcast_dimensions);
3199 }
3200 
3201 XlaOp Max(const XlaOp lhs, const XlaOp rhs,
3202           absl::Span<const int64> broadcast_dimensions) {
3203   return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
3204                                  broadcast_dimensions);
3205 }
3206 
3207 XlaOp Min(const XlaOp lhs, const XlaOp rhs,
3208           absl::Span<const int64> broadcast_dimensions) {
3209   return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
3210                                  broadcast_dimensions);
3211 }
3212 
3213 XlaOp And(const XlaOp lhs, const XlaOp rhs,
3214           absl::Span<const int64> broadcast_dimensions) {
3215   return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
3216                                  broadcast_dimensions);
3217 }
3218 
3219 XlaOp Or(const XlaOp lhs, const XlaOp rhs,
3220          absl::Span<const int64> broadcast_dimensions) {
3221   return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
3222                                  broadcast_dimensions);
3223 }
3224 
3225 XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
3226           absl::Span<const int64> broadcast_dimensions) {
3227   return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
3228                                  broadcast_dimensions);
3229 }
3230 
3231 XlaOp Not(const XlaOp operand) {
3232   return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
3233 }
3234 
3235 XlaOp PopulationCount(const XlaOp operand) {
3236   return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
3237 }
3238 
3239 XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
3240                 absl::Span<const int64> broadcast_dimensions) {
3241   return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
3242                                  broadcast_dimensions);
3243 }
3244 
3245 XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
3246                            absl::Span<const int64> broadcast_dimensions) {
3247   return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
3248                                  broadcast_dimensions);
3249 }
3250 
3251 XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
3252                         absl::Span<const int64> broadcast_dimensions) {
3253   return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
3254                                  broadcast_dimensions);
3255 }
3256 
3257 XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
3258              const XlaComputation& computation,
3259              absl::Span<const int64> dimensions_to_reduce) {
3260   return operand.builder()->Reduce(operand, init_value, computation,
3261                                    dimensions_to_reduce);
3262 }
3263 
3264 // Reduces several arrays simultaneously along the provided dimensions, using
3265 // "computation" as the reduction operator.
3266 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
3267              absl::Span<const XlaOp> init_values,
3268              const XlaComputation& computation,
3269              absl::Span<const int64> dimensions_to_reduce) {
3270   return builder->Reduce(operands, init_values, computation,
3271                          dimensions_to_reduce);
3272 }
3273 
3274 XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
3275                 const XlaComputation& computation) {
3276   return operand.builder()->ReduceAll(operand, init_value, computation);
3277 }
3278 
3279 XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
3280                    const XlaComputation& computation,
3281                    absl::Span<const int64> window_dimensions,
3282                    absl::Span<const int64> window_strides, Padding padding) {
3283   return operand.builder()->ReduceWindow(operand, init_value, computation,
3284                                          window_dimensions, window_strides,
3285                                          padding);
3286 }
3287 
3288 XlaOp ReduceWindowWithGeneralPadding(
3289     const XlaOp operand, const XlaOp init_value,
3290     const XlaComputation& computation,
3291     absl::Span<const int64> window_dimensions,
3292     absl::Span<const int64> window_strides,
3293     absl::Span<const int64> base_dilations,
3294     absl::Span<const int64> window_dilations,
3295     absl::Span<const std::pair<int64, int64>> padding) {
3296   return operand.builder()->ReduceWindowWithGeneralPadding(
3297       operand, init_value, computation, window_dimensions, window_strides,
3298       base_dilations, window_dilations, padding);
3299 }
3300 
CrossReplicaSum(const XlaOp operand,absl::Span<const ReplicaGroup> replica_groups)3301 XlaOp CrossReplicaSum(const XlaOp operand,
3302                       absl::Span<const ReplicaGroup> replica_groups) {
3303   return operand.builder()->CrossReplicaSum(operand, replica_groups);
3304 }
3305 
AllReduce(const XlaOp operand,const XlaComputation & computation,absl::Span<const ReplicaGroup> replica_groups,const absl::optional<ChannelHandle> & channel_id,const absl::optional<Shape> & shape_with_layout)3306 XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
3307                 absl::Span<const ReplicaGroup> replica_groups,
3308                 const absl::optional<ChannelHandle>& channel_id,
3309                 const absl::optional<Shape>& shape_with_layout) {
3310   return operand.builder()->AllReduce(operand, computation, replica_groups,
3311                                       channel_id, shape_with_layout);
3312 }
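
// Note on the collective wrappers above: CrossReplicaSum is the common case
// of AllReduce with a scalar addition computation, and an empty
// replica_groups span means all replicas participate in a single group.
// Illustrative sketch, reusing the `add` computation from the Reduce sketch
// and an assumed per-replica value `local_value`:
//
//   AllReduce(local_value, add, /*replica_groups=*/{});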

XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
               int64 concat_dimension, int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups) {
  return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
                                     split_count, replica_groups);
}

XlaOp CollectivePermute(
    const XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return operand.builder()->CollectivePermute(operand, source_target_pairs);
}

XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }

XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp source, const XlaOp init_value,
                       const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
                                             window_strides, padding, source,
                                             init_value, scatter);
}
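
// Note on SelectAndScatter above: it is the scatter-like dual of ReduceWindow
// and is typically used to route gradients back through a pooling window
// (select with a comparison computation, scatter with addition). Illustrative
// sketch with assumed scalar computations `ge` (greater-or-equal) and `add`,
// and assumed ops `activations` and `incoming_grad` in builder `b`:
//
//   SelectAndScatter(activations, ge, /*window_dimensions=*/{1, 2, 2, 1},
//                    /*window_strides=*/{1, 2, 2, 1}, Padding::kValid,
//                    incoming_grad, ConstantR0<float>(&b, 0.0f), add);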

XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp source,
    const XlaOp init_value, const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      init_value, scatter);
}

XlaOp Abs(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Exp(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
}
XlaOp Expm1(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
}
XlaOp Floor(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
}
XlaOp Ceil(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
}
XlaOp Round(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
XlaOp Log(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
}
XlaOp Log1p(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
}
XlaOp Sign(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
}
XlaOp Clz(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
}
XlaOp Cos(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
}
XlaOp Sin(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
}
XlaOp Tanh(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
}
XlaOp Real(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
}
XlaOp Imag(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
}
XlaOp Sqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
}
XlaOp Rsqrt(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
}

XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp IsFinite(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}

XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->ConvertElementType(operand, new_element_type);
}

XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
  return operand.builder()->BitcastConvertType(operand, new_element_type);
}

XlaOp Neg(const XlaOp operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
}

XlaOp Transpose(const XlaOp operand, absl::Span<const int64> permutation) {
  return operand.builder()->Transpose(operand, permutation);
}

XlaOp Rev(const XlaOp operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Rev(operand, dimensions);
}

XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension, bool is_stable) {
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}
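
// Usage sketch for Sort above (illustrative): the comparator takes two scalar
// parameters per sorted operand and returns a PRED scalar. Sorting an assumed
// F32 vector `values` in ascending order might look like:
//
//   XlaBuilder cmp_builder("lt");
//   XlaOp a = Parameter(&cmp_builder, 0, ShapeUtil::MakeShape(F32, {}), "a");
//   XlaOp c = Parameter(&cmp_builder, 1, ShapeUtil::MakeShape(F32, {}), "c");
//   Lt(a, c);
//   XlaComputation less_than = cmp_builder.Build().ValueOrDie();
//   Sort({values}, less_than, /*dimension=*/0, /*is_stable=*/false);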

XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
  return min.builder()->Clamp(min, operand, max);
}

XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions, static_operands);
}

XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  return init.builder()->While(condition, body, init);
}

XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}
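
// Note on the two Conditional overloads above: the first dispatches on a PRED
// scalar between a true and a false computation; the second dispatches on an
// S32 branch index over N branch computations (an out-of-range index selects
// the last branch). Each branch computation receives the corresponding
// operand as its single parameter. Illustrative sketch with assumed
// computations and operands:
//
//   Conditional(branch_index, {&branch0, &branch1, &branch2},
//               {operand0, operand1, operand2});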

XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
                      const int mantissa_bits) {
  return operand.builder()->ReducePrecision(operand, exponent_bits,
                                            mantissa_bits);
}

XlaOp Gather(const XlaOp input, const XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  return input.builder()->Gather(input, start_indices, dimension_numbers,
                                 slice_sizes, indices_are_sorted);
}
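
// Usage sketch for Gather above (illustrative): picking whole rows of an
// assumed [N, C] matrix by row index, where C stands for the column count.
// Field names follow GatherDimensionNumbers in xla_data.proto.
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // Output dim 1 holds the row data.
//   dnums.add_collapsed_slice_dims(0);  // The indexed dimension is collapsed.
//   dnums.add_start_index_map(0);       // Indices address dimension 0.
//   dnums.set_index_vector_dim(1);
//   Gather(matrix, row_indices, dnums, /*slice_sizes=*/{1, C});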

XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
              const XlaOp updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  return input.builder()->Scatter(input, scatter_indices, updates,
                                  update_computation, dimension_numbers,
                                  indices_are_sorted, unique_indices);
}

void Send(const XlaOp operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}
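
// Note on the token-based wrappers above: side-effecting ops are ordered by
// threading token values through them. CreateToken starts a chain, each
// *WithToken op consumes and produces a token, and AfterAll joins several
// chains into one. Illustrative sketch, assuming a builder `b` and an infeed
// shape `data_shape`:
//
//   XlaOp token = CreateToken(&b);
//   XlaOp infeed = InfeedWithToken(token, data_shape, /*config=*/"");
//   XlaOp data = GetTupleElement(infeed, 0);
//   XlaOp next_token = GetTupleElement(infeed, 1);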

XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64 feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64 feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64 feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
                                          grad_output, epsilon, feature_index);
}
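
// Note on the batch-norm wrappers above: BatchNormTraining returns a
// three-element tuple (normalized output, batch mean, batch variance) that
// callers unpack with GetTupleElement; BatchNormInference takes the running
// mean and variance directly and returns just the normalized output; and
// BatchNormGrad returns a tuple of gradients with respect to the operand,
// scale, and offset.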

XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}
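
// Usage sketch for the Iota wrappers above (illustrative, assuming a builder
// `b`): the first builds a rank-1 sequence [0, 1, ..., size - 1]; the second
// fills an arbitrary shape with indices along iota_dimension.
//
//   Iota(&b, S32, /*size=*/5);  // Yields [0, 1, 2, 3, 4].
//   Iota(&b, ShapeUtil::MakeShape(S32, {3, 4}), /*iota_dimension=*/1);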

XlaOp GetDimensionSize(const XlaOp operand, int64 dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
  return operand.builder()->SetDimensionSize(operand, val, dimension);
}

}  // namespace xla