/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/xla_builder.h"

#include <functional>
#include <numeric>
#include <queue>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::StrCat;

namespace {

static const char kNameSeparator = '.';
// Retrieves the base name of an instruction's or computation's fully
// qualified name, using `separator` as the boundary between the base name
// and the numeric identifier.
string GetBaseName(const string& name, char separator) {
  auto pos = name.rfind(separator);
  CHECK_NE(pos, string::npos) << name;
  return name.substr(0, pos);
}

// Generates a fully qualified computation/instruction name.
string GetFullName(const string& base_name, char separator, int64 id) {
  const char separator_str[] = {separator, '\0'};
  return StrCat(base_name, separator_str, id);
}

// Common function to standardize setting name and IDs on computation and
// instruction proto entities.
template <typename T>
void SetProtoIdAndName(T* entry, const string& base_name, char separator,
                       int64 id) {
  entry->set_id(id);
  entry->set_name(GetFullName(base_name, separator, id));
}

}  // namespace

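// Free-function operator overloads that forward to the corresponding XLA op
// builders.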
XlaOp operator-(const XlaOp& x) { return Neg(x); }
XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); }
XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); }
XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); }
XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); }
XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); }

XlaOp operator~(const XlaOp& x) { return Not(x); }
XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); }
XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); }
XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); }
XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); }

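// operator>> dispatches to an arithmetic or a logical right shift depending
// on whether the element type is signed.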
XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    if (!ShapeUtil::ElementIsIntegral(shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(shape));
    }
    if (ShapeUtil::ElementIsSigned(shape)) {
      return ShiftRightArithmetic(x, y);
    } else {
      return ShiftRightLogical(x, y);
    }
  });
}

StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
  TF_RETURN_IF_ERROR(first_error_);

  TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
  return Shape(instr->shape());
}

StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    absl::Span<const XlaOp> operands) const {
  std::vector<Shape> operand_shapes;
  for (const XlaOp& operand : operands) {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
    operand_shapes.push_back(shape);
  }
  return operand_shapes;
}

XlaBuilder::XlaBuilder(const string& computation_name)
    : name_(computation_name) {}

XlaBuilder::~XlaBuilder() {}

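// Records the given error. Dies immediately if die_immediately_on_error_ is
// set; otherwise keeps only the first error (plus a backtrace) so it can be
// surfaced later, e.g. by Build(). Returns an XlaOp tied to this builder so
// op construction can continue without null checks.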
XlaOp XlaBuilder::ReportError(const Status& error) {
  CHECK(!error.ok());
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
  return XlaOp(this);
}

XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
  if (!first_error_.ok()) {
    return XlaOp(this);
  }
  if (!op.ok()) {
    return ReportError(op.status());
  }
  return op.ValueOrDie();
}

XlaOp XlaBuilder::ReportErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  return ReportErrorOrReturn(op_creator());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
                      LookUpInstructionByHandle(root_id));

  ProgramShape program_shape;

  *program_shape.mutable_result() = Shape(root_proto->shape());

  // Check that the parameter numbers are contiguous from 0, and add parameter
  // shapes and names to the program shape.
  const int64 param_count = parameter_numbers_.size();
  for (int64 i = 0; i < param_count; i++) {
    program_shape.add_parameters();
    program_shape.add_parameter_names();
  }
  for (const HloInstructionProto& instr : instructions_) {
    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
    // to verify contiguity, we just need to verify that every parameter is in
    // the right range.
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      const int64 index = instr.parameter_number();
      TF_RET_CHECK(index >= 0 && index < param_count)
          << "invalid parameter number: " << index;
      *program_shape.mutable_parameters(index) = Shape(instr.shape());
      *program_shape.mutable_parameter_names(index) = instr.name();
    }
  }
  return program_shape;
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  TF_RET_CHECK(!instructions_.empty());
  return GetProgramShape(instructions_.back().id());
}

StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return GetProgramShape(root.handle());
}

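// Recursively visits the operands of the instruction with handle `op_handle`,
// clearing *is_constant when it reaches an op whose value cannot be computed
// at compile time (e.g. kParameter, kRng, kInfeed). `visited` memoizes
// handles that have already been processed.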
void XlaBuilder::IsConstantVisitor(const int64 op_handle,
                                   absl::flat_hash_set<int64>* visited,
                                   bool* is_constant) const {
  if (visited->contains(op_handle) || !*is_constant) {
    return;
  }

  const HloInstructionProto& instr =
      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
  switch (opcode) {
    default:
      for (const int64 operand_id : instr.operand_ids()) {
        IsConstantVisitor(operand_id, visited, is_constant);
      }
      // TODO(b/32495713): We aren't checking the called computations.
      break;
    case HloOpcode::kGetDimensionSize: {
      int64 dimension_number = instr.dimensions(0);
      const HloInstructionProto& operand =
          *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
      Shape operand_shape(operand.shape());
      if (operand_shape.is_dynamic_dimension(dimension_number)) {
        *is_constant = false;
      }
      break;
    }

    // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant. We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kScatter:
      // TODO(b/32495713): We aren't checking the embedded computation in
      // Scatter.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter:
      *is_constant = false;
      break;
  }
  if (!*is_constant) {
    VLOG(1) << "Non-constant: " << instr.name();
  }
  visited->insert(op_handle);
}


Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
                                     ShapeIndex dynamic_size_param_index,
                                     int64 target_param_num,
                                     ShapeIndex target_param_index,
                                     int64 target_dim_num) {
  bool param_exists = false;
  for (HloInstructionProto& instr : instructions_) {
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
        instr.parameter_number() == target_param_num) {
      param_exists = true;
      Shape param_shape(instr.shape());
      Shape* param_shape_ptr = &param_shape;
      for (int64 index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      // TODO(b/121223198): Set `is_dynamic` on the parameter shape once the
      // XLA backend can handle dynamic dimensions.
      *instr.mutable_shape() = param_shape.ToProto();
    }
  }

  if (!param_exists) {
    return InvalidArgument(
        "Asked to mark parameter %lld as dynamic sized parameter, but the "
        "parameter doesn't exist",
        target_param_num);
  }

  TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
      DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
                                                dynamic_size_param_index},
      DynamicParameterBinding::DynamicDimension{
          target_param_num, target_param_index, target_dim_num}));
  return Status::OK();
}


XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->ReportError(
        AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
    return {};
  }
  return build_status.ConsumeValueOrDie();
}

Status XlaBuilder::GetCurrentStatus() const {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }
  return Status::OK();
}

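// Builds the computation, using the last added instruction as the root.
//
// A minimal usage sketch (assuming the free-function op wrappers declared in
// xla_builder.h):
//
//   XlaBuilder b("add_one");
//   XlaOp x = Parameter(&b, /*parameter_number=*/0,
//                       ShapeUtil::MakeShape(F32, {}), "x");
//   Add(x, ConstantR0<float>(&b, 1.0f));
//   StatusOr<XlaComputation> computation = b.Build();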
StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());
  return Build(instructions_.back().id(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
                                           bool remove_dynamic_dimensions) {
  if (root.builder_ != this) {
    return InvalidArgument("Given root operation is not in this computation.");
  }
  return Build(root.handle(), remove_dynamic_dimensions);
}

StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
                                           bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());

  // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet,
  // so remove all dynamic dimensions before building the XLA program, until
  // the backend supports them.
  if (remove_dynamic_dimensions) {
    std::function<void(ShapeProto*)> remove_dynamic_dimension =
        [&](ShapeProto* shape) {
          if (shape->tuple_shapes_size() != 0) {
            for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) {
              remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
            }
          }
          for (int64 i = 0; i < shape->dimensions_size(); ++i) {
            shape->set_is_dynamic_dimension(i, false);
          }
        };

    for (auto& instruction : instructions_) {
      remove_dynamic_dimension(instruction.mutable_shape());
    }
  }


  HloComputationProto entry;
  SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
  *entry.mutable_program_shape() = program_shape.ToProto();
  entry.set_root_id(root_id);

  for (auto& instruction : instructions_) {
    // Ensures that the instruction names are unique among the whole graph.
    instruction.set_name(
        GetFullName(instruction.name(), kNameSeparator, instruction.id()));
    entry.add_instructions()->Swap(&instruction);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = entry.program_shape();
  for (auto& e : embedded_) {
    module->add_computations()->Swap(&e.second);
  }
  module->add_computations()->Swap(&entry);
  if (!input_output_aliases_.empty()) {
    TF_RETURN_IF_ERROR(
        PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
  }
  *(module->mutable_dynamic_parameter_binding()) =
      dynamic_parameter_binding_.ToProto();

  // Clear data held by this builder.
  this->instructions_.clear();
  this->handle_to_index_.clear();
  this->embedded_.clear();
  this->parameter_numbers_.clear();

  return std::move(computation);
}


/* static */ Status XlaBuilder::PopulateInputOutputAlias(
    HloModuleProto* module, const ProgramShape& program_shape,
    const std::vector<InputOutputAlias>& input_output_aliases) {
  HloInputOutputAliasConfig config(program_shape.result());
  for (auto& alias : input_output_aliases) {
    // The HloInputOutputAliasConfig does not do parameter validation, as it
    // only carries the result shape. Maybe it should be constructed with a
    // ProgramShape to allow full validation. We will still get an error when
    // trying to compile the HLO module, but it would be better to have
    // validation at this stage.
    if (alias.param_number >= program_shape.parameters_size()) {
      return InvalidArgument("Invalid parameter number %ld (total %ld)",
                             alias.param_number,
                             program_shape.parameters_size());
    }
    const Shape& parameter_shape = program_shape.parameters(alias.param_number);
    if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
      return InvalidArgument("Invalid parameter %ld index: %s",
                             alias.param_number,
                             alias.param_index.ToString().c_str());
    }
    TF_RETURN_IF_ERROR(config.SetUpAlias(
        alias.output_index, alias.param_number, alias.param_index,
        HloInputOutputAliasConfig::AliasKind::kUserAlias));
  }
  *module->mutable_input_output_alias() = config.ToProto();
  return Status::OK();
}


StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, const XlaOp& operand,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64 dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }
  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

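// Broadcasts `operand` up to `output_shape`. The operand must either be a
// scalar or have the same rank as the output, with each dimension either
// matching the output dimension or having size 1; size-1 (degenerate)
// dimensions are reshaped away and then broadcast back in.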
StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                 const XlaOp& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));

  CHECK(ShapeUtil::IsScalar(operand_shape) ||
        operand_shape.rank() == output_shape.rank());
  Shape broadcast_shape =
      ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type());

  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(operand_shape)) {
    return InDimBroadcast(broadcast_shape, operand, {});
  }

  // Do explicit broadcast for degenerate broadcast.
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  for (int i = 0; i < operand_shape.rank(); i++) {
    if (operand_shape.dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_shape.dimensions(i));
    } else {
      TF_RET_CHECK(operand_shape.dimensions(i) == 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand shape: "
          << operand_shape << "; output_shape: " << output_shape;
    }
  }
  // Eliminate the size-one dimensions.
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
                      Reshape(ShapeUtil::MakeShape(operand_shape.element_type(),
                                                   reshaped_dimensions),
                              operand));
  // Broadcast `reshaped_operand` up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
}

XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferUnaryOpShape(unop, operand_shape));
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), unop, {operand});
  });
}

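// Builds a binary op, inserting explicit broadcasts where the operand shapes
// disagree with the inferred result shape: first a rank-mismatch broadcast
// driven by `broadcast_dimensions`, then a degenerate-dimension broadcast via
// AddBroadcastSequence. `direction` must be set iff `binop` is kCompare.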
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> broadcast_dimensions,
                           absl::optional<ComparisonDirection> direction) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferBinaryOpShape(
                            binop, lhs_shape, rhs_shape, broadcast_dimensions));
    *instr.mutable_shape() = shape.ToProto();
    if (binop == HloOpcode::kCompare) {
      if (!direction.has_value()) {
        return InvalidArgument(
            "kCompare expects a ComparisonDirection, but none provided.");
      }
      instr.set_comparison_direction(ComparisonDirectionToString(*direction));
    } else if (direction.has_value()) {
      return InvalidArgument(
          "A comparison direction is provided for a non-compare opcode: %s.",
          HloOpcodeString(binop));
    }

    const int64 lhs_rank = lhs_shape.rank();
    const int64 rhs_rank = rhs_shape.rank();

    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;

    if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
      const bool should_broadcast_lhs = lhs_rank < rhs_rank;
      XlaOp from = should_broadcast_lhs ? lhs : rhs;
      const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape;

      std::vector<int64> to_size;
      std::vector<bool> to_size_is_dynamic;
      for (int i = 0; i < shape.rank(); i++) {
        to_size.push_back(shape.dimensions(i));
        to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
      }
      for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
        int64 to_dim = broadcast_dimensions[from_dim];
        to_size[to_dim] = from_shape.dimensions(from_dim);
        to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
      }

      const Shape& broadcasted_shape = ShapeUtil::MakeShape(
          from_shape.element_type(), to_size, to_size_is_dynamic);
      TF_ASSIGN_OR_RETURN(
          XlaOp broadcasted_operand,
          InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));

      updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
      updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
    }

    TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs));
    if (!ShapeUtil::SameDimensions(shape, updated_lhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_lhs,
                          AddBroadcastSequence(shape, updated_lhs));
    }
    TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs));
    if (!ShapeUtil::SameDimensions(shape, updated_rhs_shape)) {
      TF_ASSIGN_OR_RETURN(updated_rhs,
                          AddBroadcastSequence(shape, updated_rhs));
    }

    return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
  });
}

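// Builds a ternary op, converting any implicitly broadcast non-tuple operand
// into an explicit broadcast sequence against the inferred result shape.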
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
                            const XlaOp& ehs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape,
                                                         rhs_shape, ehs_shape));
    *instr.mutable_shape() = shape.ToProto();
    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    XlaOp updated_ehs = ehs;
    if (!shape.IsTuple()) {
      if (!lhs_shape.IsTuple() &&
          !ShapeUtil::SameDimensions(shape, lhs_shape)) {
        // lhs is being implicitly broadcasted. Change to explicit.
        TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs));
      }
      if (!rhs_shape.IsTuple() &&
          !ShapeUtil::SameDimensions(shape, rhs_shape)) {
        // rhs is being implicitly broadcasted. Change to explicit.
        TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs));
      }
      if (!ehs_shape.IsTuple() &&
          !ShapeUtil::SameDimensions(shape, ehs_shape)) {
        // ehs is being implicitly broadcasted. Change to explicit.
        TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs));
      }
    }
    return AddInstruction(std::move(instr), triop,
                          {updated_lhs, updated_rhs, updated_ehs});
  });
}

XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = literal.shape().ToProto();
    *instr.mutable_literal() = literal.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kConstant);
  });
}

XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(iota_dimension);
    return AddInstruction(std::move(instr), HloOpcode::kIota);
  });
}

XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
  return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}


XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       absl::Span<const XlaOp> operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
                                         operand_shape_ptrs,
                                         /*to_apply=*/called_program_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
  });
}

XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
                            const string& name) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!parameter_numbers_.insert(parameter_number).second) {
      return InvalidArgument("parameter %d already registered",
                             parameter_number);
    }
    instr.set_parameter_number(parameter_number);
    instr.set_name(name);
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kParameter);
  });
}


XlaOp XlaBuilder::Broadcast(const XlaOp& operand,
                            absl::Span<const int64> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes));

    // The client-level broadcast op just appends dimensions on the left (adds
    // lowest numbered dimensions). The HLO broadcast instruction is more
    // flexible and can add new dimensions anywhere. The instruction's
    // dimensions field maps operand dimensions to dimensions in the broadcast
    // output, so to append dimensions on the left the instruction's dimensions
    // should just be the n highest dimension numbers of the output shape where
    // n is the number of input dimensions.
    const int64 operand_rank = operand_shape.rank();
    std::vector<int64> dimensions(operand_rank);
    for (int i = 0; i < operand_rank; ++i) {
      dimensions[i] = i + shape.rank() - operand_rank;
    }
    return InDimBroadcast(shape, operand, dimensions);
  });
}


XlaOp XlaBuilder::BroadcastInDim(
    const XlaOp& operand, const absl::Span<const int64> out_dim_size,
    const absl::Span<const int64> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    // Output shape: in the case of a degenerate broadcast, out_dim_size is
    // not necessarily the same as the dimension sizes of the output shape.
    auto output_shape =
        ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size);
    for (int i = 0; i < broadcast_dimensions.size(); i++) {
      if (broadcast_dimensions[i] < 0 ||
          broadcast_dimensions[i] >= out_dim_size.size()) {
        return InvalidArgument("Broadcast dimension %lld is out of bounds",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(broadcast_dimensions[i],
                                         operand_shape.is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           operand_shape, output_shape, broadcast_dimensions)
                           .status());
    std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_dimensions.size(); i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape.dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape.element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If the broadcast is not degenerate, return the broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle the degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}

StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
}


XlaOp XlaBuilder::Slice(const XlaOp& operand,
                        absl::Span<const int64> start_indices,
                        absl::Span<const int64> limit_indices,
                        absl::Span<const int64> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferSliceShape(
                         operand_shape, start_indices, limit_indices, strides));
    *instr.mutable_shape() = shape.ToProto();
    for (int i = 0; i < start_indices.size(); i++) {
      auto* slice_config = instr.add_slice_dimensions();
      slice_config->set_start(start_indices[i]);
      slice_config->set_limit(limit_indices[i]);
      slice_config->set_stride(strides[i]);
    }

    return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
  });
}

XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
                             int64 limit_index, int64 stride, int64 dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
    std::vector<int64> starts(shape.rank(), 0);
    std::vector<int64> limits(shape.dimensions().begin(),
                              shape.dimensions().end());
    std::vector<int64> strides(shape.rank(), 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}


XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
                        GetShape(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            operand_shape, {start_indices_shape}, slice_sizes));
    *instr.mutable_shape() = shape.ToProto();

    for (int64 size : slice_sizes) {
      instr.add_dynamic_slice_sizes(size);
    }

    return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice,
                          {operand, start_indices});
  });
}

XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            operand_shape, start_indices_shapes, slice_sizes));
    *instr.mutable_shape() = shape.ToProto();

    for (int64 size : slice_sizes) {
      instr.add_dynamic_slice_sizes(size);
    }

    std::vector<XlaOp> operands = {operand};
    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
    return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
  });
}


XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                     const XlaOp& start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
    TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
                        GetShape(start_indices));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                         operand_shape, update_shape, {start_indices_shape}));
    *instr.mutable_shape() = shape.ToProto();

    return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                          {operand, update, start_indices});
  });
}

XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
    std::vector<const Shape*> start_indices_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    absl::c_transform(start_indices_shapes,
                      std::back_inserter(start_indices_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicUpdateSliceShape(
                            operand_shape, update_shape, start_indices_shapes));
    *instr.mutable_shape() = shape.ToProto();

    std::vector<XlaOp> operands = {operand, update};
    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
    return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                          operands);
  });
}


XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                              int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
                                         operand_shape_ptrs, dimension));
    *instr.mutable_shape() = shape.ToProto();

    instr.add_dimensions(dimension);

    return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
  });
}


XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
                      const PaddingConfig& padding_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape,
                        GetShape(padding_value));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferPadShape(
                         operand_shape, padding_value_shape, padding_config));
    *instr.mutable_shape() = shape.ToProto();
    *instr.mutable_padding_config() = padding_config;

    return AddInstruction(std::move(instr), HloOpcode::kPad,
                          {operand, padding_value});
  });
}


XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                          absl::Span<const int64> dimensions,
                          absl::Span<const int64> new_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferReshapeShape(
                            operand_shape, dimensions, new_sizes));
    XlaOp transposed = IsIdentityPermutation(dimensions)
                           ? operand
                           : Transpose(operand, dimensions);
    return Reshape(shape, transposed);
  });
}

XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                          absl::Span<const int64> new_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, GetShape(operand));
    std::vector<int64> dimensions(shape.dimensions_size());
    std::iota(dimensions.begin(), dimensions.end(), 0);
    return Reshape(operand, dimensions, new_sizes);
  });
}


XlaOp XlaBuilder::Collapse(const XlaOp& operand,
                           absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Nothing to collapse; trivially return the operand rather than
      // enqueueing a no-op reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Checks that the collapsed dimensions are in order and consecutive.
    for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
    TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand));

    VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape);
    VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");

    std::vector<int64> new_sizes;
    for (int i = 0; i < original_shape.rank(); ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape.dimensions(i));
      } else {
        new_sizes.back() *= original_shape.dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";

    return Reshape(operand, new_sizes);
  });
}


void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
  });
}

XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
                         const XlaOp& on_false) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true));
    TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false));
    TF_RET_CHECK(true_shape.IsTuple() == false_shape.IsTuple());
    HloOpcode opcode =
        true_shape.IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
    return TernaryOp(opcode, pred, on_true, on_false);
  });
}


XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferVariadicOpShape(
                            HloOpcode::kTuple, operand_shape_ptrs));
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
  });
}

XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
    if (!tuple_shape.IsTuple()) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(tuple_shape));
    }
    *instr.mutable_shape() =
        ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto();

    instr.set_tuple_index(index);

    return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
                          {tuple_data});
  });
}


XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
                      const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));

    DotDimensionNumbers dimension_numbers;
    dimension_numbers.add_lhs_contracting_dimensions(
        lhs_shape.dimensions_size() == 1 ? 0 : 1);
    dimension_numbers.add_rhs_contracting_dimensions(0);
    return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
  });
}

XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                             const DotDimensionNumbers& dimension_numbers,
                             const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    // If one operand is a scalar, just multiply the two operands.
    if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
      if (dimension_numbers.rhs_batch_dimensions_size() != 0 ||
          dimension_numbers.lhs_batch_dimensions_size() != 0 ||
          dimension_numbers.rhs_contracting_dimensions_size() != 0 ||
          dimension_numbers.lhs_contracting_dimensions_size() != 0) {
        return InvalidArgument(
            "Dots with scalar operands must have no contracting or batch "
            "dimensions");
      }
      return xla::Mul(lhs, rhs);
    }
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
                                                        dimension_numbers));
    *instr.mutable_shape() = shape.ToProto();
    *instr.mutable_dot_dimension_numbers() = dimension_numbers;
    if (precision_config != nullptr) {
      *instr.mutable_precision_config() = *precision_config;
    }
    return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
  });
}

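// Checks that the convolution dimension numbers are consistent with the
// operand shapes: the operands must have equal rank, and each spatial
// dimension field must have one entry per spatial dimension, with every
// entry in range.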
Status XlaBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) const {
  if (lhs_shape.rank() != rhs_shape.rank()) {
    return InvalidArgument(
        "Convolution arguments must have same number of "
        "dimensions. Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_dims = lhs_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          return InvalidArgument("Expected %d elements for %s, but got %d.",
                                 num_spatial_dims, field_name, numbers.size());
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
                                   field_name, i, numbers.Get(i));
          }
        }
        return Status::OK();
      };
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("input_spatial_dimensions",
                               dimension_numbers.input_spatial_dimensions()));
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("kernel_spatial_dimensions",
                               dimension_numbers.kernel_spatial_dimensions()));
  return check_spatial_dimensions(
      "output_spatial_dimensions",
      dimension_numbers.output_spatial_dimensions());
}

XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
                       absl::Span<const int64> window_strides, Padding padding,
                       int64 feature_group_count, int64 batch_group_count,
                       const PrecisionConfig* precision_config) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()),
      feature_group_count, batch_group_count, precision_config);
}

XlaOp XlaBuilder::ConvWithGeneralPadding(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()),
                     feature_group_count, batch_group_count, precision_config);
}

XlaOp XlaBuilder::ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));

    TF_RETURN_IF_ERROR(
        VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));

    std::vector<int64> base_area_dimensions(
        dimension_numbers.input_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
         ++i) {
      base_area_dimensions[i] =
          lhs_shape.dimensions(dimension_numbers.input_spatial_dimensions(i));
    }

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }

    return ConvGeneral(lhs, rhs, window_strides,
                       MakePadding(base_area_dimensions, window_dimensions,
                                   window_strides, padding),
                       dimension_numbers, feature_group_count,
                       batch_group_count, precision_config);
  });
}

XlaOp XlaBuilder::ConvGeneral(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers, feature_group_count,
                            batch_group_count, precision_config);
}

XlaOp XlaBuilder::ConvGeneralDilated(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));

    std::vector<int64> window_dimensions(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
         ++i) {
      window_dimensions[i] =
          rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i));
    }
    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                        MakeWindow(window_dimensions, window_strides, padding,
                                   lhs_dilation, rhs_dilation));

    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferConvolveShape(
                         lhs_shape, rhs_shape, feature_group_count,
                         batch_group_count, instr.window(), dimension_numbers));
    *instr.mutable_shape() = shape.ToProto();

    *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
    instr.set_feature_group_count(feature_group_count);
    instr.set_batch_group_count(batch_group_count);

    if (precision_config != nullptr) {
      *instr.mutable_precision_config() = *precision_config;
    }

    return AddInstruction(std::move(instr), HloOpcode::kConvolution,
                          {lhs, rhs});
  });
}

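// Builds a Window proto from the given per-dimension vectors. Each auxiliary
// vector may be empty, in which case a neutral default is used (stride 1,
// padding 0, dilation 1, no reversal); otherwise its length must match the
// number of window dimensions.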
StatusOr<Window> XlaBuilder::MakeWindow(
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation,
    absl::Span<const int64> rhs_dilation) const {
  const auto verify_size = [&](const size_t x, const char* x_name) {
    if (x == 0 || x == window_dimensions.size()) {
      return Status::OK();
    } else {
      return InvalidArgument(
          "%s", absl::StrCat(
                    "Window has different number of window dimensions than of ",
                    x_name,
                    "\nNumber of window dimensions: ", window_dimensions.size(),
                    "\nNumber of ", x_name, ": ", x, "\n"));
    }
  };
  TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
  TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
  TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
  TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));

  Window window;
  for (size_t i = 0; i < window_dimensions.size(); i++) {
    auto dim = window.add_dimensions();
    dim->set_size(window_dimensions[i]);
    if (!window_strides.empty()) {
      dim->set_stride(window_strides[i]);
    } else {
      dim->set_stride(1);
    }
    if (!padding.empty()) {
      dim->set_padding_low(padding[i].first);
      dim->set_padding_high(padding[i].second);
    } else {
      dim->set_padding_low(0);
      dim->set_padding_high(0);
    }
    if (!lhs_dilation.empty()) {
      dim->set_base_dilation(lhs_dilation[i]);
    } else {
      dim->set_base_dilation(1);
    }
    if (!rhs_dilation.empty()) {
      dim->set_window_dilation(rhs_dilation[i]);
    } else {
      dim->set_window_dilation(1);
    }
    dim->set_window_reversal(false);
  }
  return window;
}

XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
                      const absl::Span<const int64> fft_length) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
                                         operand_shape, fft_type, fft_length));
    *instr.mutable_shape() = shape.ToProto();
    instr.set_fft_type(fft_type);
    for (int64 i : fft_length) {
      instr.add_fft_length(i);
    }

    return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
  });
}


XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
    instr.set_infeed_config(config);

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() &&
        sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    // Infeed takes a single token operand. Generate the token to pass to the
    // infeed.
    XlaOp token;
    auto make_token = [&]() {
      HloInstructionProto token_instr;
      *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
      return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
    };
    if (sharding()) {
      // Arbitrarily assign token to device 0.
      OpSharding sharding = sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this, sharding);
      TF_ASSIGN_OR_RETURN(token, make_token());
    } else {
      TF_ASSIGN_OR_RETURN(token, make_token());
    }

    // The sharding is set by the client according to the data tuple shape.
    // However, the shape of the infeed instruction is a tuple containing the
    // data and a token. For tuple sharding type, the sharding must be changed
    // to accommodate the token.
    XlaOp infeed;
    if (sharding() &&
        sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) {
      // TODO(b/80000000): Remove this when clients have been updated to handle
      // tokens.
      OpSharding infeed_instruction_sharding = *sharding();
      // Arbitrarily assign the token to device 0.
      *infeed_instruction_sharding.add_tuple_shardings() =
          sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this,
                                                  infeed_instruction_sharding);
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    } else {
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    }

    // The infeed instruction produces a tuple of the data and a token.
    // Return the XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto infeed_data;
    *infeed_data.mutable_shape() = shape.ToProto();
    infeed_data.set_tuple_index(0);
    return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
                          {infeed});
  });
}

XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape,
                                  const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
    instr.set_infeed_config(config);

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() &&
        sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
  });
}

void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
                         const string& outfeed_config) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();

    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(operand_shape));
    }
    *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();

    instr.set_outfeed_config(outfeed_config);

    // Outfeed takes a token as its second operand. Generate the token to pass
    // to the outfeed.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    TF_RETURN_IF_ERROR(
        AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
            .status());

    // The outfeed instruction produces a token. However, existing users expect
    // a nil shape (empty tuple). This should only be relevant if the outfeed
    // is the root of a computation.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto tuple_instr;
    *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();

    // The dummy tuple should have no sharding.
    {
      XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
      TF_ASSIGN_OR_RETURN(
          XlaOp empty_tuple,
          AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
      return empty_tuple;
    }
  });
}

XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                                   const Shape& shape_with_layout,
                                   const string& outfeed_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();

    // Check and set outfeed shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(operand_shape));
    }
    *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();

    instr.set_outfeed_config(outfeed_config);

    return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
                          {operand, token});
  });
}

XlaOp XlaBuilder::CreateToken() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
  });
}

XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (tokens.empty()) {
      return InvalidArgument("AfterAll requires at least one operand");
    }
    for (int i = 0; i < tokens.size(); ++i) {
      const XlaOp& operand = tokens[i];
      TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
      if (!operand_shape.IsToken()) {
        return InvalidArgument(
            "All operands to AfterAll must be tokens; operand %d has shape %s",
            i, ShapeUtil::HumanString(operand_shape));
      }
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
  });
}

XlaOp XlaBuilder::CustomCall(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.set_custom_call_target(call_target_name);
    instr.set_custom_call_opaque(opaque);
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      instr.set_constrain_layout(true);
      int64 operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
        ++operand_num;
      }
    }
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
  });
}

XlaOp XlaBuilder::Transpose(const XlaOp& operand,
                            absl::Span<const int64> permutation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
                                         operand_shape, permutation));
    *instr.mutable_shape() = shape.ToProto();
    for (int64 dim : permutation) {
      instr.add_dimensions(dim);
    }
    return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
  });
}

XlaOp XlaBuilder::Rev(const XlaOp& operand,
                      absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
                                         operand_shape, dimensions));
    *instr.mutable_shape() = shape.ToProto();
    for (int64 dim : dimensions) {
      instr.add_dimensions(dim);
    }
    return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
  });
}

namespace {
// Switch from a floating point value to an integer value in such a way that
// when using the integer value to compare, we get the same result for normal
// values, -NaN is treated as the smallest value, and NaN is treated as the
// largest value.
// If f is a float, and
// x = bit_cast<int32>(f);
// y = x < 0 ? numeric_limits<int32>::max() - x : x;
// then y is ordered as an int32 such that finite values have the obvious
// order, -0 is ordered before 0, and -NaN and NaN appear at the beginning and
// end of the ordering.
// Note that in order to avoid overflow, we compute
// numeric_limits<int32>::max() - x as unsigned, and then convert back to
// signed.
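// Worked example (illustrative, not executed here): f = +0.0f gives x = 0 and
// y = 0, while f = -0.0f gives x = INT32_MIN and
// y = INT32_MAX - INT32_MIN (computed as unsigned) = -1, so -0.0 is ordered
// immediately before +0.0.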
XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
                                            int64 bit_width) {
  PrimitiveType signed_type;
  PrimitiveType unsigned_type;
  XlaOp max_value;
  switch (bit_width) {
    case 16:
      max_value =
          ConstantR0(value.builder(),
                     static_cast<uint16>(std::numeric_limits<int16>::max()));
      signed_type = S16;
      unsigned_type = U16;
      break;
    case 32:
      max_value =
          ConstantR0(value.builder(),
                     static_cast<uint32>(std::numeric_limits<int32>::max()));
      signed_type = S32;
      unsigned_type = U32;
      break;
    case 64:
      max_value =
          ConstantR0(value.builder(),
                     static_cast<uint64>(std::numeric_limits<int64>::max()));
      signed_type = S64;
      unsigned_type = U64;
      break;
    default:
      return value.builder()->ReportError(
          InvalidArgument("Invalid bit width %lld for Comparator floating "
                          "point parameter.",
                          bit_width));
  }
  auto signed_value = BitcastConvertType(value, signed_type);
  auto unsigned_value = BitcastConvertType(value, unsigned_type);
  auto flipped_value =
      BitcastConvertType(Sub(max_value, unsigned_value), signed_type);
  auto is_negative =
      Lt(signed_value,
         ConstantLiteral(value.builder(), LiteralUtil::Zero(signed_type)));
  return Select(is_negative, flipped_value, signed_value);
}
}  // namespace

XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
                       int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::vector<XlaOp> operands{keys};
    for (const XlaOp& value : values) {
      operands.push_back(value);
    }
    // Build the default less-than comparator (copied from lib/comparators.cc).
    // TODO(b/122298745): Remove the deprecated API method so that this code
    // duplication can be deleted.
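    // For two sort operands (keys plus one values array), the comparator
    // built below takes four scalar parameters named p.0.lhs, p.0.rhs,
    // p.1.lhs, and p.1.rhs, and compares only the key pair; the value
    // parameters exist solely to satisfy the expected comparator signature.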
    auto b = this->CreateSubBuilder("comparator");
    std::vector<PrimitiveType> operand_types;
    for (const XlaOp& operand : operands) {
      TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand));
      operand_types.push_back(operand_shape.element_type());
    }

    int64 parameter_count = 0;
    XlaOp first_lhs_param;
    XlaOp first_rhs_param;

    for (auto operand_type : operand_types) {
      auto scalar_shape = ShapeUtil::MakeShape(operand_type, {});
      auto lhs_param =
          b->Parameter(parameter_count * 2, scalar_shape,
                       absl::StrCat("p.", parameter_count, ".lhs"));
      auto rhs_param =
          b->Parameter(parameter_count * 2 + 1, scalar_shape,
                       absl::StrCat("p.", parameter_count, ".rhs"));
      if (parameter_count == 0) {
        first_lhs_param = lhs_param;
        first_rhs_param = rhs_param;
      }
      ++parameter_count;
    }
    if (primitive_util::IsFloatingPointType(operand_types[0])) {
      PrimitiveType compare_type = operand_types[0];
      // Special-case handling for BF16. We currently do not support direct
      // comparisons with BF16, so we convert to F32 and then use the F32
      // comparison logic.
      if (compare_type == BF16) {
        compare_type = F32;
        first_lhs_param = b->ConvertElementType(first_lhs_param, F32);
        first_rhs_param = b->ConvertElementType(first_rhs_param, F32);
      }
      int64 bit_width = primitive_util::BitWidth(compare_type);
      first_lhs_param =
          BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width);
      first_rhs_param =
          BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width);
    }
    Lt(first_lhs_param, first_rhs_param);

    TF_ASSIGN_OR_RETURN(auto comparator, b->Build());
    return Sort(operands, comparator, dimension, /*is_stable=*/false);
  });
}

XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       const XlaComputation& comparator, int64 dimension,
                       bool is_stable) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_is_stable(is_stable);
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                        GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
                                         HloOpcode::kSort, operand_shape_ptrs));
    *instr.mutable_shape() = shape.ToProto();
    if (dimension == -1) {
      TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(operands[0]));
      dimension = keys_shape.rank() - 1;
    }
    instr.add_dimensions(dimension);
    AddCalledComputation(comparator, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
  });
}

XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         operand_shape, new_element_type));
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
  });
}

XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                         operand_shape, new_element_type));
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
                          {operand});
  });
}
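
// Note on the distinction (illustrative examples, not exercised here):
// ConvertElementType lowers to kConvert, which converts values (e.g. the f32
// 1.0f becomes the s32 1), whereas BitcastConvertType lowers to
// kBitcastConvert, which reinterprets the underlying bits of a same-width
// type without changing them.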

XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand,
                        const XlaOp& max) {
  return TernaryOp(HloOpcode::kClamp, min, operand, max);
}

XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions,
                      absl::Span<const XlaOp> static_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!static_operands.empty()) {
      return Unimplemented("static_operands is not supported in Map");
    }

    HloInstructionProto instr;
    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferMapShape(
                         operand_shape_ptrs, called_program_shape, dimensions));
    *instr.mutable_shape() = shape.ToProto();

    Shape output_shape(instr.shape());
    const int64 output_rank = output_shape.rank();
    AddCalledComputation(computation, &instr);
    std::vector<XlaOp> new_operands(operands.begin(), operands.end());
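    // Implicitly broadcast operands whose rank or dimensions differ from the
    // inferred output. E.g. mapping a scalar f32[] alongside an f32[2, 3]
    // operand first broadcasts the scalar up to f32[2, 3] (illustrative
    // shapes).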
    for (XlaOp& new_operand : new_operands) {
      TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand));
      const int64 rank = shape.rank();
      if (rank != output_rank) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            InDimBroadcast(output_shape, new_operand, {}));
        TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand));
      }
      if (!ShapeUtil::SameDimensions(output_shape, shape)) {
        TF_ASSIGN_OR_RETURN(new_operand,
                            AddBroadcastSequence(output_shape, new_operand));
      }
    }

    return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
  });
}

XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
                        absl::Span<const XlaOp> parameters,
                        const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    // Check the number of parameters per RNG distribution.
    switch (distribution) {
      case RandomDistribution::RNG_NORMAL:
      case RandomDistribution::RNG_UNIFORM:
        if (parameters.size() != 2) {
          return InvalidArgument(
              "RNG distribution (%s) expects 2 parameters, but got %ld",
              RandomDistribution_Name(distribution), parameters.size());
        }
        break;
      default:
        LOG(FATAL) << "unhandled distribution " << distribution;
    }

    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_distribution(distribution);

    return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
  });
}

XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma,
                            const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
                             const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}

XlaOp XlaBuilder::While(const XlaComputation& condition,
                        const XlaComputation& body, const XlaOp& init) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    // Infer shape.
    TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
                        condition.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
                                         condition_program_shape,
                                         body_program_shape, init_shape));
    *instr.mutable_shape() = shape.ToProto();
    // Body comes before condition computation in the vector.
    AddCalledComputation(body, &instr);
    AddCalledComputation(condition, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
  });
}

XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
                         const GatherDimensionNumbers& dimension_numbers,
                         absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
    TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
                        GetShape(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
                                         input_shape, start_indices_shape,
                                         dimension_numbers, slice_sizes));
    *instr.mutable_shape() = shape.ToProto();

    *instr.mutable_gather_dimension_numbers() = dimension_numbers;
    for (int64 bound : slice_sizes) {
      instr.add_gather_slice_sizes(bound);
    }

    return AddInstruction(std::move(instr), HloOpcode::kGather,
                          {input, start_indices});
  });
}

XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                          const XlaOp& updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
    TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape,
                        GetShape(scatter_indices));
    TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates));
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        update_computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferScatterShape(
                            input_shape, scatter_indices_shape, updates_shape,
                            to_apply_shape, dimension_numbers));
    *instr.mutable_shape() = shape.ToProto();

    *instr.mutable_scatter_dimension_numbers() = dimension_numbers;

    AddCalledComputation(update_computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kScatter,
                          {input, scatter_indices, updates});
  });
}

XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                              const XlaComputation& true_computation,
                              const XlaOp& false_operand,
                              const XlaComputation& false_computation) {
  // The index of true_computation must be 0 and that of false_computation
  // must be 1.
  return Conditional(predicate, {&true_computation, &false_computation},
                     {true_operand, false_operand});
}

XlaOp XlaBuilder::Conditional(
    const XlaOp& branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& branch_index_shape,
                        GetShape(branch_index));
    std::vector<Shape> branch_operand_shapes(branch_operands.size());
    std::vector<ProgramShape> branch_computation_shapes(
        branch_computations.size());
    for (int j = 0; j < branch_operands.size(); ++j) {
      TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
                          GetShape(branch_operands[j]));
      TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
                          branch_computations[j]->GetProgramShape());
    }
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferConditionalShape(
                            branch_index_shape, branch_computation_shapes,
                            branch_operand_shapes));
    *instr.mutable_shape() = shape.ToProto();

    for (const XlaComputation* branch_computation : branch_computations) {
      AddCalledComputation(*branch_computation, &instr);
    }

    std::vector<XlaOp> operands(1, branch_index);
    for (const XlaOp branch_operand : branch_operands) {
      operands.emplace_back(branch_operand);
    }
    return AddInstruction(std::move(instr), HloOpcode::kConditional,
                          absl::MakeSpan(operands));
  });
}

XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value,
                         const XlaComputation& computation,
                         absl::Span<const int64> dimensions_to_reduce) {
  return Reduce(absl::Span<const XlaOp>({operand}),
                absl::Span<const XlaOp>({init_value}), computation,
                dimensions_to_reduce);
}

XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
                         absl::Span<const XlaOp> init_values,
                         const XlaComputation& computation,
                         absl::Span<const int64> dimensions_to_reduce) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());

    std::vector<XlaOp> all_operands;
    all_operands.insert(all_operands.end(), operands.begin(), operands.end());
    all_operands.insert(all_operands.end(), init_values.begin(),
                        init_values.end());

    std::vector<const Shape*> operand_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
                        GetOperandShapes(all_operands));
    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
                      [](const Shape& shape) { return &shape; });

    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferReduceShape(
            operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
    *instr.mutable_shape() = shape.ToProto();

    for (int64 dim : dimensions_to_reduce) {
      instr.add_dimensions(dim);
    }

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
  });
}

XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                            const XlaComputation& computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    std::vector<int64> all_dimnos(operand_shape.rank());
    std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
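    // E.g. a rank-3 operand reduces over dimensions {0, 1, 2}, yielding a
    // scalar (illustrative; all_dimnos is simply 0..rank-1).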
    return Reduce(operand, init_value, computation, all_dimnos);
  });
}

XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_RETURN_IF_ERROR(
        ValidatePaddingValues(AsInt64Slice(operand_shape.dimensions()),
                              window_dimensions, window_strides));

    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
                    window_strides, padding);
    return ReduceWindowWithGeneralPadding(
        operand, init_value, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}

XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    const XlaOp& operand, const XlaOp& init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                        MakeWindow(window_dimensions, window_strides, padding,
                                   /*lhs_dilation=*/base_dilations,
                                   /*rhs_dilation=*/window_dilations));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
                                         operand_shape, init_shape,
                                         instr.window(), to_apply_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
                          {operand, init_value});
  });
}

XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                                    const XlaOp& offset, float epsilon,
                                    int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
    TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferBatchNormTrainingShape(
            operand_shape, scale_shape, offset_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
                          {operand, scale, offset});
  });
}

XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                                     const XlaOp& offset, const XlaOp& mean,
                                     const XlaOp& variance, float epsilon,
                                     int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
    TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
    TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean));
    TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferBatchNormInferenceShape(
                         operand_shape, scale_shape, offset_shape, mean_shape,
                         variance_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
                          {operand, scale, offset, mean, variance});
  });
}

XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                                const XlaOp& batch_mean, const XlaOp& batch_var,
                                const XlaOp& grad_output, float epsilon,
                                int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
    TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean));
    TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var));
    TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferBatchNormGradShape(
                            operand_shape, scale_shape, batch_mean_shape,
                            batch_var_shape, grad_output_shape, feature_index));
    *instr.mutable_shape() = shape.ToProto();

    instr.set_epsilon(epsilon);
    instr.set_feature_index(feature_index);

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
                          {operand, scale, batch_mean, batch_var, grad_output});
  });
}

XlaOp XlaBuilder::CrossReplicaSum(
    const XlaOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
    const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
    auto b = CreateSubBuilder("sum");
    Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
        b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
    TF_ASSIGN_OR_RETURN(auto computation, b->Build());
    return CrossReplicaSum(operand, computation, replica_groups,
                           /*channel_id=*/absl::nullopt);
  });
}

XlaOp XlaBuilder::CrossReplicaSum(
    const XlaOp& operand, const XlaComputation& computation,
    absl::Span<const ReplicaGroup> replica_groups,
    const absl::optional<ChannelHandle>& channel_id) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferAllReduceShape({&operand_shape}));
    *instr.mutable_shape() = shape.ToProto();

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }

    if (channel_id.has_value()) {
      instr.set_all_reduce_id(channel_id->handle());
    }

    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kAllReduce, {operand});
  });
}

XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
                           int64 concat_dimension, int64 split_count,
                           const std::vector<ReplicaGroup>& replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));

    // The HloInstruction for AllToAll currently only handles the data
    // communication: it accepts N already-split parts and scatters them to N
    // cores, and each core gathers the N received parts into a tuple as the
    // output. So here we explicitly split the operand before the HLO
    // all-to-all, and concat the tuple elements afterwards.
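    //
    // For example (illustrative values): an f32[8,16] operand with
    // split_dimension=0, concat_dimension=1, and split_count=4 is sliced into
    // four f32[2,16] parts, exchanged via the HLO all-to-all, and the four
    // received parts are concatenated into an f32[2,64] result.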
    //
    // First, run shape inference to make sure the shapes are valid.
    TF_RETURN_IF_ERROR(
        ShapeInference::InferAllToAllShape(operand_shape, split_dimension,
                                           concat_dimension, split_count)
            .status());

    // Split into N parts.
    std::vector<XlaOp> slices;
    slices.reserve(split_count);
    const int64 block_size =
        operand_shape.dimensions(split_dimension) / split_count;
    for (int i = 0; i < split_count; i++) {
      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
                                  /*limit_index=*/(i + 1) * block_size,
                                  /*stride=*/1, /*dimno=*/split_dimension));
    }

    // Handle data communication.
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
    std::vector<const Shape*> slice_shape_ptrs;
    absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
    *instr.mutable_shape() = shape.ToProto();
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    TF_ASSIGN_OR_RETURN(
        XlaOp alltoall,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));

    // Concat the N received parts.
    std::vector<XlaOp> received;
    received.reserve(split_count);
    for (int i = 0; i < split_count; i++) {
      received.push_back(this->GetTupleElement(alltoall, i));
    }
    return this->ConcatInDim(received, concat_dimension);
  });
}

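// Illustrative note (an assumption about CollectivePermute semantics, not
// stated in this file): with source_target_pairs = {{0, 1}, {1, 0}}, replicas
// 0 and 1 exchange their operands, and a replica that is not named as a
// target receives an all-zero result.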
XlaOp XlaBuilder::CollectivePermute(
    const XlaOp& operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(
        Shape shape,
        ShapeInference::InferCollectivePermuteShape(operand_shape));
    *instr.mutable_shape() = shape.ToProto();

    for (const auto& pair : source_target_pairs) {
      auto* proto_pair = instr.add_source_target_pairs();
      proto_pair->set_source(pair.first);
      proto_pair->set_target(pair.second);
    }

    return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
                          {operand});
  });
}

XlaOp XlaBuilder::ReplicaId() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
  });
}

XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand,
                                   const XlaComputation& select,
                                   absl::Span<const int64> window_dimensions,
                                   absl::Span<const int64> window_strides,
                                   Padding padding, const XlaOp& source,
                                   const XlaOp& init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    return SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides,
        MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
                    window_strides, padding),
        source, init_value, scatter);
  });
}

XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    const XlaOp& operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    const XlaOp& init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& source_shape, GetShape(source));
    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
    TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
                        select.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
                        scatter.GetProgramShape());
    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                        MakeWindow(window_dimensions, window_strides, padding,
                                   /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferSelectAndScatterShape(
                            operand_shape, select_shape, instr.window(),
                            source_shape, init_shape, scatter_shape));
    *instr.mutable_shape() = shape.ToProto();

    AddCalledComputation(select, &instr);
    AddCalledComputation(scatter, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
                          {operand, source, init_value});
  });
}

XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferReducePrecisionShape(
                            operand_shape, exponent_bits, mantissa_bits));
    *instr.mutable_shape() = shape.ToProto();
    instr.set_exponent_bits(exponent_bits);
    instr.set_mantissa_bits(mantissa_bits);
    return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
                          {operand});
  });
}

void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Send HLO takes two operands: a data operand and a token. Generate the
    // token to pass into the send.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    return SendWithToken(operand, token, handle);
  });
}

XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Send must use a device-to-device channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Recv HLO takes a single token operand. Generate the token to pass into
    // the Recv and RecvDone instructions.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto token_instr;
    *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
                                                    HloOpcode::kAfterAll, {}));

    XlaOp recv = RecvWithToken(token, shape, handle);

    // The RecvDone instruction produces a tuple of the data and a token.
    // Return the XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto recv_data;
    *recv_data.mutable_shape() = shape.ToProto();
    recv_data.set_tuple_index(0);
    return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
                          {recv});
  });
}

XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Recv must use a device-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}

XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
                             const Shape& shape_with_layout,
                             const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Shape passed to SendToHost must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "SendToHost shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(operand_shape));
    }
    // TODO(b/111544877): Support tuple shapes.
    if (!operand_shape.IsArray()) {
      return InvalidArgument("SendToHost only supports array shapes, shape: %s",
                             ShapeUtil::HumanString(operand_shape));
    }

    if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
      return InvalidArgument("SendToHost must use a device-to-host channel");
    }

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape_with_layout,
                                   ShapeUtil::MakeShape(U32, {}),
                                   ShapeUtil::MakeTokenShape()})
            .ToProto();
    send_instr.set_channel_id(handle.handle());
    send_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    send_done_instr.set_channel_id(handle.handle());
    send_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}

XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
                               const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Shape passed to RecvFromHost must have a layout");
    }

    // TODO(b/111544877): Support tuple shapes.
    if (!shape.IsArray()) {
      return InvalidArgument(
          "RecvFromHost only supports array shapes, shape: %s",
          ShapeUtil::HumanString(shape));
    }

    if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
      return InvalidArgument("RecvFromHost must use a host-to-device channel");
    }

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_instr.set_channel_id(handle.handle());
    recv_instr.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    recv_done_instr.set_channel_id(handle.handle());
    recv_done_instr.set_is_host_transfer(true);
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}

XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
                                         operand_shape, dimension));
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
                          {operand});
  });
}

StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
  TF_RETURN_IF_ERROR(first_error_);

  // Verify that the handle is valid.
  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());

  bool is_constant = true;
  absl::flat_hash_set<int64> visited;
  IsConstantVisitor(operand.handle(), &visited, &is_constant);
  return is_constant;
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    const XlaOp& root_op) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);
    string op_string =
        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
    return InvalidArgument(
        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
        "  op requested for constant subgraph: %s\n\n"
        "This is an internal error that typically happens when the XLA user "
        "(e.g. TensorFlow) is attempting to determine a value that must be a "
        "compile-time constant (e.g. an array dimension) but it is not capable "
        "of being evaluated at XLA compile time.\n\n"
        "Please file a usability bug with the framework being used (e.g. "
        "TensorFlow).",
        op_string);
  }

  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                    GetNextId());
  entry.set_root_id(root->id());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() = root->shape();

  // We use std::set to keep the instruction ids in ascending order (which is
  // also a valid dependency order). The related ops will be added to the
  // subgraph in the same order.
  std::set<int64> related_ops;
  absl::flat_hash_set<int64> related_calls;  // Related computations.
  std::queue<int64> worklist;
  worklist.push(root->id());
  related_ops.insert(root->id());
  while (!worklist.empty()) {
    int64 handle = worklist.front();
    worklist.pop();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(handle));

    if (instr_proto->opcode() ==
        HloOpcodeString(HloOpcode::kGetDimensionSize)) {
      // At this point, BuildConstantSubGraph should never encounter a
      // GetDimensionSize with a dynamic dimension; the IsConstant check at
      // the beginning of this function would have failed otherwise.
      //
      // Replace GetDimensionSize with a Constant representing the static bound
      // of the shape.
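      // For example (illustrative shapes): a GetDimensionSize of dimension 0
      // of an f32[16, 8] operand is replaced below by a u32 constant 16.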
2570 int64 dimension = instr_proto->dimensions(0);
2571 int64 operand_handle = instr_proto->operand_ids(0);
2572 TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
2573 LookUpInstructionByHandle(operand_handle));
2574
2575 TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension));
2576 auto constant_dimension_size =
2577 static_cast<uint32>(operand_proto->shape().dimensions(dimension));
2578
2579 Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
2580
2581 HloInstructionProto const_instr;
2582 *const_instr.mutable_shape() = literal.shape().ToProto();
2583 *const_instr.mutable_literal() = literal.ToProto();
2584 *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
2585
2586 const_instr.set_id(handle);
2587 *const_instr.mutable_name() =
2588 GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
2589 *entry.add_instructions() =
2590 const_instr; // Add to the result constant graph.
2591 } else {
2592 for (int64 id : instr_proto->operand_ids()) {
2593 if (related_ops.insert(id).second) {
2594 worklist.push(id);
2595 }
2596 }
2597 for (int64 called_id : instr_proto->called_computation_ids()) {
2598 related_calls.insert(called_id);
2599 }
2600 }
2601 }
2602
2603 // Add related ops to the computation.
2604 for (int64 id : related_ops) {
2605 TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
2606 LookUpInstructionByHandle(id));
2607
2608 if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
2609 continue;
2610 }
2611 auto* instr = entry.add_instructions();
2612
2613 *instr = *instr_src;
    // Ensure that the instruction names are unique within the graph.
    const string& new_name =
        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
    instr->set_name(new_name);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& e : embedded_) {
    if (related_calls.find(e.second.id()) != related_calls.end()) {
      *module->add_computations() = e.second;
    }
  }
  *module->add_computations() = std::move(entry);

  return std::move(computation);
}
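
// A minimal usage sketch (illustrative only; ConstantR0 is the scalar
// constant helper declared with the other free functions in xla_builder.h,
// and error handling is elided):
//
//   XlaBuilder b("example");
//   XlaOp sum = Add(ConstantR0<int32>(&b, 2), ConstantR0<int32>(&b, 3));
//   TF_ASSIGN_OR_RETURN(bool constant, b.IsConstant(sum));  // true
//   TF_ASSIGN_OR_RETURN(XlaComputation subgraph,
//                       b.BuildConstantSubGraph(sum));
//
// Since the extracted subgraph depends on no Parameter, a compile-time
// evaluator can reduce it to the value 5 without executing the full graph.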

std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
    const string& computation_name) {
  auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
  sub_builder->parent_builder_ = this;
  sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
  return sub_builder;
}

/* static */ ConvolutionDimensionNumbers
XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_kernel_output_feature_dimension(
      kConvKernelOutputDimension);
  dimension_numbers.set_kernel_input_feature_dimension(
      kConvKernelInputDimension);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 2);
    dimension_numbers.add_kernel_spatial_dimensions(i + 2);
    dimension_numbers.add_output_spatial_dimensions(i + 2);
  }
  return dimension_numbers;
}
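
// Assuming the usual kConv* constants from xla_builder.h (batch and kernel
// output-feature dimension 0, feature and kernel input-feature dimension 1),
// the defaults built above place the spatial dimensions at indices 2, 3, ...;
// for num_spatial_dims == 2 this corresponds to the familiar NCHW activation
// layout and OIHW kernel layout.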

/* static */ Status XlaBuilder::Validate(
    const ConvolutionDimensionNumbers& dnum) {
  if (dnum.input_spatial_dimensions_size() < 2) {
    return FailedPrecondition("input spatial dimension < 2: %d",
                              dnum.input_spatial_dimensions_size());
  }
  if (dnum.kernel_spatial_dimensions_size() < 2) {
    return FailedPrecondition("kernel spatial dimension < 2: %d",
                              dnum.kernel_spatial_dimensions_size());
  }
  if (dnum.output_spatial_dimensions_size() < 2) {
    return FailedPrecondition("output spatial dimension < 2: %d",
                              dnum.output_spatial_dimensions_size());
  }

  // Only the batch/feature dimensions and the first two spatial dimensions
  // are checked for uniqueness.
  if (std::set<int64>(
          {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
           dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%d, %d, %d, %d)",
        dnum.input_batch_dimension(), dnum.input_feature_dimension(),
        dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
  }
  if (std::set<int64>({dnum.kernel_output_feature_dimension(),
                       dnum.kernel_input_feature_dimension(),
                       dnum.kernel_spatial_dimensions(0),
                       dnum.kernel_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%d, %d, %d, %d)",
        dnum.kernel_output_feature_dimension(),
        dnum.kernel_input_feature_dimension(),
        dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
  }
  if (std::set<int64>({dnum.output_batch_dimension(),
                       dnum.output_feature_dimension(),
                       dnum.output_spatial_dimensions(0),
                       dnum.output_spatial_dimensions(1)})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%d, %d, %d, %d)",
        dnum.output_batch_dimension(), dnum.output_feature_dimension(),
        dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
  }
  return Status::OK();
}

StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
                                           HloOpcode opcode,
                                           absl::Span<const XlaOp> operands) {
  TF_RETURN_IF_ERROR(first_error_);

  // Allocate a fresh handle and fill in the opcode and default name.
  const int64 handle = GetNextId();
  instr.set_id(handle);
  instr.set_opcode(HloOpcodeString(opcode));
  if (instr.name().empty()) {
    instr.set_name(instr.opcode());
  }
  // Record the operand handles, verifying that each operand was produced by
  // this builder; mixing ops from different builders is an error.
  for (const auto& operand : operands) {
    if (operand.builder_ == nullptr) {
      return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
    }
    if (operand.builder_ != this) {
      return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
                             operand.builder_->name(), this->name());
    }
    instr.add_operand_ids(operand.handle());
  }

  // Attach the current OpMetadata and, if set, the current sharding.
  *instr.mutable_metadata() = metadata_;
  if (sharding_) {
    *instr.mutable_sharding() = *sharding_;
  }

  // Register the instruction so LookUpInstructionByHandle can find it.
  handle_to_index_[handle] = instructions_.size();
  instructions_.push_back(std::move(instr));

  XlaOp op(handle, this);
  return op;
}

void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
                                      HloInstructionProto* instr) {
  absl::flat_hash_map<int64, int64> remapped_ids;
  std::vector<HloComputationProto> imported_computations;
  imported_computations.reserve(computation.proto().computations_size());
  // First, import the computations, assigning fresh IDs (IDs must be unique
  // within this builder) and capturing the old->new mappings in remapped_ids.
  for (const HloComputationProto& e : computation.proto().computations()) {
    HloComputationProto new_computation(e);
    int64 computation_id = GetNextId();
    remapped_ids[new_computation.id()] = computation_id;
    SetProtoIdAndName(&new_computation,
                      GetBaseName(new_computation.name(), kNameSeparator),
                      kNameSeparator, computation_id);
    for (auto& instruction : *new_computation.mutable_instructions()) {
      int64 instruction_id = GetNextId();
      remapped_ids[instruction.id()] = instruction_id;
      SetProtoIdAndName(&instruction,
                        GetBaseName(instruction.name(), kNameSeparator),
                        kNameSeparator, instruction_id);
    }
    new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));

    imported_computations.push_back(std::move(new_computation));
  }
  // Once all the computations have been imported and all the ID mappings
  // captured, go back and fix up the IDs referenced inside them.
  instr->add_called_computation_ids(
      remapped_ids.at(computation.proto().entry_computation_id()));
  for (auto& imported_computation : imported_computations) {
    for (auto& instruction : *imported_computation.mutable_instructions()) {
      for (auto& operand_id : *instruction.mutable_operand_ids()) {
        operand_id = remapped_ids.at(operand_id);
      }
      for (auto& control_predecessor_id :
           *instruction.mutable_control_predecessor_ids()) {
        control_predecessor_id = remapped_ids.at(control_predecessor_id);
      }
      for (auto& called_computation_id :
           *instruction.mutable_called_computation_ids()) {
        called_computation_id = remapped_ids.at(called_computation_id);
      }
    }

    int64 computation_id = imported_computation.id();
    embedded_.insert({computation_id, std::move(imported_computation)});
  }
}

StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
    const XlaOp& op) const {
  TF_RETURN_IF_ERROR(first_error_);

  if (op.builder_ == nullptr) {
    return InvalidArgument(
        "invalid XlaOp with handle %d; the builder of this op has been freed",
        op.handle());
  }
  if (op.builder_ != this) {
    return InvalidArgument(
        "XlaOp with handle %d was built by builder '%s' but is being used by "
        "builder '%s'",
        op.handle(), op.builder_->name(), this->name());
  }

  return LookUpInstructionByHandle(op.handle());
}

StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
    int64 handle) const {
  auto it = handle_to_index_.find(handle);
  if (it == handle_to_index_.end()) {
    return InvalidArgument("No XlaOp with handle %d", handle);
  }
  return &instructions_[it->second];
}

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name) {
  return builder->Parameter(parameter_number, shape, name);
}
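
// For example (a sketch; ShapeUtil::MakeShape is the shape factory from
// shape_util.h, and the scalar in the Add is broadcast implicitly):
//
//   XlaBuilder b("add_one");
//   XlaOp x = Parameter(&b, /*parameter_number=*/0,
//                       ShapeUtil::MakeShape(F32, {10}), "x");
//   Add(x, ConstantR0<float>(&b, 1.0f));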

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
  return builder->ConstantLiteral(literal);
}

XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes) {
  return operand.builder()->Broadcast(operand, broadcast_sizes);
}

XlaOp BroadcastInDim(const XlaOp& operand,
                     const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions) {
  return operand.builder()->BroadcastInDim(operand, out_dim_size,
                                           broadcast_dimensions);
}
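
// Here broadcast_dimensions[i] names the output dimension that operand
// dimension i maps to. For example, broadcasting f32[3] to f32[2,3] along
// the minor dimension (a sketch with illustrative names):
//
//   XlaOp matrix = BroadcastInDim(vector, /*out_dim_size=*/{2, 3},
//                                 /*broadcast_dimensions=*/{1});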

XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
          const PaddingConfig& padding_config) {
  return operand.builder()->Pad(operand, padding_value, padding_config);
}

XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes) {
  return operand.builder()->Reshape(operand, dimensions, new_sizes);
}

XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes) {
  return operand.builder()->Reshape(operand, new_sizes);
}

XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Collapse(operand, dimensions);
}

XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides) {
  return operand.builder()->Slice(operand, start_indices, limit_indices,
                                  strides);
}

XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno) {
  return operand.builder()->SliceInDim(operand, start_index, limit_index,
                                       stride, dimno);
}

XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                   absl::Span<const int64> slice_sizes) {
  return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}
XlaOp DynamicSlice(const XlaOp& operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes) {
  return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}

XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         const XlaOp& start_indices) {
  return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}

XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         absl::Span<const XlaOp> start_indices) {
  return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}

XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64 dimension) {
  return builder->ConcatInDim(operands, dimension);
}

void Trace(const string& tag, const XlaOp& operand) {
  return operand.builder()->Trace(tag, operand);
}

XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) {
  return pred.builder()->Select(pred, on_true, on_false);
}

XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
  return builder->Tuple(elements);
}

XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
  return tuple_data.builder()->GetTupleElement(tuple_data, index);
}

XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}

XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
}

XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
}

XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
}

XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
}

XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
}

XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction) {
  return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                 broadcast_dimensions, direction);
}

XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
          const PrecisionConfig* precision_config) {
  return lhs.builder()->Dot(lhs, rhs, precision_config);
}

XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config) {
  return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
                                   precision_config);
}

XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> window_strides, Padding padding,
           int64 feature_group_count, int64 batch_group_count,
           const PrecisionConfig* precision_config) {
  return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
                             feature_group_count, batch_group_count,
                             precision_config);
}

XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
                             absl::Span<const int64> window_strides,
                             absl::Span<const std::pair<int64, int64>> padding,
                             int64 feature_group_count, int64 batch_group_count,
                             const PrecisionConfig* precision_config) {
  return lhs.builder()->ConvWithGeneralPadding(
      lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
      precision_config);
}

XlaOp ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  return lhs.builder()->ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
      batch_group_count, precision_config);
}

XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64 feature_group_count, int64 batch_group_count,
                  const PrecisionConfig* precision_config) {
  return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
                                    dimension_numbers, feature_group_count,
                                    batch_group_count, precision_config);
}

XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64 feature_group_count, int64 batch_group_count,
                         const PrecisionConfig* precision_config) {
  return lhs.builder()->ConvGeneralDilated(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config);
}

XlaOp Fft(const XlaOp& operand, FftType fft_type,
          absl::Span<const int64> fft_length) {
  return operand.builder()->Fft(operand, fft_type, fft_length);
}

XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(const Shape& b_shape, builder->GetShape(b));
    xla::TriangularSolveOptions& options =
        *instr.mutable_triangular_solve_options();
    options.set_left_side(left_side);
    options.set_lower(lower);
    options.set_unit_diagonal(unit_diagonal);
    options.set_transpose_a(transpose_a);
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
                                         a_shape, b_shape, options));
    *instr.mutable_shape() = shape.ToProto();

    return builder->AddInstruction(std::move(instr),
                                   HloOpcode::kTriangularSolve, {a, b});
  });
}

XlaOp Cholesky(XlaOp a, bool lower) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a));
    xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
    options.set_lower(lower);
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferCholeskyShape(a_shape));
    *instr.mutable_shape() = shape.ToProto();

    return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
  });
}
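
// A sketch of how the two ops above compose to solve A x = b for a symmetric
// positive-definite A (illustrative names; error handling elided):
//
//   XlaOp l = Cholesky(a, /*lower=*/true);  // A = L * L^T
//   XlaOp y = TriangularSolve(l, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);  // L y = b
//   XlaOp x = TriangularSolve(l, y, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::TRANSPOSE);  // L^T x = y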

XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
  return builder->Infeed(shape, config);
}

void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
             const string& outfeed_config) {
  return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
}

XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands) {
  return builder->Call(computation, operands);
}

XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                 absl::Span<const XlaOp> operands, const Shape& shape,
                 const string& opaque) {
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             /*operand_shapes_with_layout=*/absl::nullopt);
}

XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
                           absl::Span<const XlaOp> operands, const Shape& shape,
                           absl::Span<const Shape> operand_shapes_with_layout,
                           const string& opaque) {
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             operand_shapes_with_layout);
}

XlaOp Complex(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Conj(const XlaOp& operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}
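// (Conj computes the complex conjugate, conj(a + bi) = a - bi, hence the
// negated imaginary component above.)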

XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Not(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
}

XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                        absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return operand.builder()->Reduce(operand, init_value, computation,
                                   dimensions_to_reduce);
}

// Reduces several arrays simultaneously along the provided dimensions, using
// "computation" as the reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  return builder->Reduce(operands, init_values, computation,
                         dimensions_to_reduce);
}
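
// For instance, a plain sum over dimension 0 of an f32 array can be written
// as (a sketch; CreateScalarAddComputation is the helper from
// tensorflow/compiler/xla/client/lib/arithmetic.h):
//
//   XlaComputation add = CreateScalarAddComputation(F32, builder);
//   XlaOp sum = Reduce(builder, {operand},
//                      {ConstantR0<float>(builder, 0.0f)}, add,
//                      /*dimensions_to_reduce=*/{0});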

XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                const XlaComputation& computation) {
  return operand.builder()->ReduceAll(operand, init_value, computation);
}

XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  return operand.builder()->ReduceWindow(operand, init_value, computation,
                                         window_dimensions, window_strides,
                                         padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp& operand, const XlaOp& init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return operand.builder()->ReduceWindowWithGeneralPadding(
      operand, init_value, computation, window_dimensions, window_strides,
      base_dilations, window_dilations, padding);
}

XlaOp CrossReplicaSum(const XlaOp& operand,
                      absl::Span<const ReplicaGroup> replica_groups) {
  return operand.builder()->CrossReplicaSum(operand, replica_groups);
}

XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const absl::optional<ChannelHandle>& channel_id) {
  return operand.builder()->CrossReplicaSum(operand, computation,
                                            replica_groups, channel_id);
}

XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
               int64 concat_dimension, int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups) {
  return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
                                     split_count, replica_groups);
}

XlaOp CollectivePermute(
    const XlaOp& operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return operand.builder()->CollectivePermute(operand, source_target_pairs);
}

XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }

XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp& source, const XlaOp& init_value,
                       const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
                                             window_strides, padding, source,
                                             init_value, scatter);
}

XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp& operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    const XlaOp& init_value, const XlaComputation& scatter) {
  return operand.builder()->SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      init_value, scatter);
}

XlaOp Abs(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp Atan2(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp Exp(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
}
XlaOp Expm1(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
}
XlaOp Floor(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
}
XlaOp Ceil(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
}
XlaOp Round(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
XlaOp Log(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
}
XlaOp Log1p(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
}
XlaOp Sign(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
}
XlaOp Clz(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
}
XlaOp Cos(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
}
XlaOp Sin(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
}
XlaOp Tanh(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
}
XlaOp Real(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
}
XlaOp Imag(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
}
XlaOp Sqrt(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
}
XlaOp Rsqrt(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
}

XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
                                 broadcast_dimensions);
}

XlaOp IsFinite(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}

XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) {
  return operand.builder()->ConvertElementType(operand, new_element_type);
}

XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
  return operand.builder()->BitcastConvertType(operand, new_element_type);
}

XlaOp Neg(const XlaOp& operand) {
  return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
}

XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation) {
  return operand.builder()->Transpose(operand, permutation);
}

XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions) {
  return operand.builder()->Rev(operand, dimensions);
}

XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values, int64 dimension) {
  return keys.builder()->Sort(keys, values, dimension);
}

XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension, bool is_stable) {
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}
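
// A comparator is itself an XlaComputation mapping 2*N scalar parameters to a
// PRED result. For a plain ascending f32 sort, one would typically build it
// with the helper assumed to live in
// tensorflow/compiler/xla/client/lib/comparators.h (a sketch):
//
//   XlaComputation less = CreateScalarLtComputation({F32}, keys.builder());
//   XlaOp sorted = Sort({keys}, less, /*dimension=*/0, /*is_stable=*/true);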

XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
  return min.builder()->Clamp(min, operand, max);
}

XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions, static_operands);
}

XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) {
  return mu.builder()->RngNormal(mu, sigma, shape);
}

XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) {
  return a.builder()->RngUniform(a, b, shape);
}

XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp& init) {
  return init.builder()->While(condition, body, init);
}

XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp& false_operand,
                  const XlaComputation& false_computation) {
  return predicate.builder()->Conditional(predicate, true_operand,
                                          true_computation, false_operand,
                                          false_computation);
}

XlaOp Conditional(const XlaOp& branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  return branch_index.builder()->Conditional(branch_index, branch_computations,
                                             branch_operands);
}

XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                      const int mantissa_bits) {
  return operand.builder()->ReducePrecision(operand, exponent_bits,
                                            mantissa_bits);
}

XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes) {
  return input.builder()->Gather(input, start_indices, dimension_numbers,
                                 slice_sizes);
}

XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
              const XlaOp& updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers) {
  return input.builder()->Scatter(input, scatter_indices, updates,
                                  update_computation, dimension_numbers);
}

void Send(const XlaOp& operand, const ChannelHandle& handle) {
  return operand.builder()->Send(operand, handle);
}

XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(shape, handle);
}

XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                    const ChannelHandle& handle) {
  return operand.builder()->SendWithToken(operand, token, handle);
}

XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                    const ChannelHandle& handle) {
  return token.builder()->RecvWithToken(token, shape, handle);
}

XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  return operand.builder()->SendToHost(operand, token, shape_with_layout,
                                       handle);
}

XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                   const ChannelHandle& handle) {
  return token.builder()->RecvFromHost(token, shape, handle);
}

XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                      const string& config) {
  return token.builder()->InfeedWithToken(token, shape, config);
}

XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config) {
  return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
                                             outfeed_config);
}

XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }

XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(tokens);
}

XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                        const XlaOp& offset, float epsilon,
                        int64 feature_index) {
  return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
                                              feature_index);
}

XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                         const XlaOp& offset, const XlaOp& mean,
                         const XlaOp& variance, float epsilon,
                         int64 feature_index) {
  return operand.builder()->BatchNormInference(
      operand, scale, offset, mean, variance, epsilon, feature_index);
}

XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                    const XlaOp& batch_mean, const XlaOp& batch_var,
                    const XlaOp& grad_output, float epsilon,
                    int64 feature_index) {
  return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
                                          grad_output, epsilon, feature_index);
}

XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
  return builder->Iota(type, size);
}

XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
  return builder->Iota(shape, iota_dimension);
}

XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension) {
  return operand.builder()->GetDimensionSize(operand, dimension);
}

}  // namespace xla