1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/xla_builder.h"
17
18 #include <functional>
19 #include <numeric>
20 #include <queue>
21 #include <string>
22 #include <utility>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/match.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_join.h"
31 #include "tensorflow/compiler/xla/client/sharding_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/execution_options_util.h"
34 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/util.h"
40
41 namespace xla {
42
43 using absl::StrCat;
44
45 namespace {
46
47 static const char kNameSeparator = '.';
48
49 // Retrieves the base name from an instruction or computation's fully qualified
50 // name, using separator as the boundary between the initial base name part and
51 // the numeric identification.
52 string GetBaseName(const string& name, char separator) {
53 auto pos = name.rfind(separator);
54 CHECK_NE(pos, string::npos) << name;
55 return name.substr(0, pos);
56 }
57
58 // Generates a fully qualified computation/instruction name.
59 string GetFullName(const string& base_name, char separator, int64 id) {
60 const char separator_str[] = {separator, '\0'};
61 return StrCat(base_name, separator_str, id);
62 }
63
64 // Common function to standardize setting name and IDs on computation and
65 // instruction proto entities.
66 template <typename T>
67 void SetProtoIdAndName(T* entry, const string& base_name, char separator,
68 int64 id) {
69 entry->set_id(id);
70 entry->set_name(GetFullName(base_name, separator, id));
71 }
72
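// Looks up the HloInstructionProto with the given handle in the builder's
// instruction table, returning an InvalidArgument error if no instruction with
// that handle exists.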
73 template <typename InstructionType>
74 StatusOr<InstructionType> LookUpInstructionByHandleInternal(
75 const absl::flat_hash_map<int64, int64>& handle_to_index,
76 const std::vector<HloInstructionProto>& instructions, int64 handle) {
77 auto it = handle_to_index.find(handle);
78 if (it == handle_to_index.end()) {
79 return InvalidArgument("No XlaOp with handle %d", handle);
80 }
81 return const_cast<InstructionType>(&instructions.at(it->second));
82 }
83
84 Status CheckBuildersAffinity(const XlaBuilder* op_builder,
85 const XlaBuilder* builder, int64 handle) {
86 if (op_builder != builder) {
87 return InvalidArgument(
88 "XlaOp with handle %d is built by builder '%s', but is trying to use "
89 "it in builder '%s'",
90 handle, op_builder->name(), builder->name());
91 }
92 return Status::OK();
93 }
94
95 template <typename InstructionType, typename OpBuilderType,
96 typename BuilderType, typename OpType>
97 StatusOr<InstructionType> LookUpInstructionInternal(
98 const absl::flat_hash_map<int64, int64>& handle_to_index,
99 const std::vector<HloInstructionProto>& instructions,
100 OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
101 if (op_builder == nullptr) {
102 return InvalidArgument(
103 "Invalid XlaOp with handle %d; the builder of this op is freed",
104 op_handle);
105 }
106 TF_RETURN_IF_ERROR(CheckBuildersAffinity(op_builder, builder, op_handle));
107 return LookUpInstructionByHandleInternal<InstructionType>(
108 handle_to_index, instructions, op_handle);
109 }
110
111 } // namespace
112
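// Free-function operator overloads that forward to the corresponding XLA
// client library ops.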
113 XlaOp operator-(XlaOp x) { return Neg(x); }
114 XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); }
115 XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); }
116 XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); }
117 XlaOp operator/(XlaOp x, XlaOp y) { return Div(x, y); }
118 XlaOp operator%(XlaOp x, XlaOp y) { return Rem(x, y); }
119
120 XlaOp operator~(XlaOp x) { return Not(x); }
121 XlaOp operator&(XlaOp x, XlaOp y) { return And(x, y); }
122 XlaOp operator|(XlaOp x, XlaOp y) { return Or(x, y); }
123 XlaOp operator^(XlaOp x, XlaOp y) { return Xor(x, y); }
124 XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); }
125
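// operator>> lowers to an arithmetic right shift for signed integral element
// types and to a logical right shift for unsigned ones; non-integral element
// types are rejected below.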
126 XlaOp operator>>(XlaOp x, XlaOp y) {
127 XlaBuilder* builder = x.builder();
128 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
129 TF_ASSIGN_OR_RETURN(const xla::Shape* shape, builder->GetShapePtr(x));
130 if (!ShapeUtil::ElementIsIntegral(*shape)) {
131 return InvalidArgument(
132 "Argument to >> operator does not have an integral type (%s).",
133 ShapeUtil::HumanString(*shape));
134 }
135 if (ShapeUtil::ElementIsSigned(*shape)) {
136 return ShiftRightArithmetic(x, y);
137 } else {
138 return ShiftRightLogical(x, y);
139 }
140 });
141 }
142
143 StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
144 TF_RETURN_IF_ERROR(first_error_);
145 TF_RETURN_IF_ERROR(CheckBuildersAffinity(op.builder(), this, op.handle()));
146 auto it = handle_to_index_.find(op.handle());
147 if (it == handle_to_index_.end()) {
148 return InvalidArgument("No XlaOp with handle %d", op.handle());
149 }
150 return instruction_shapes_.at(it->second).get();
151 }
152
153 StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
154 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op));
155 return *shape;
156 }
157
158 StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
159 absl::Span<const XlaOp> operands) const {
160 std::vector<Shape> operand_shapes;
161 for (XlaOp operand : operands) {
162 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
163 operand_shapes.push_back(*shape);
164 }
165 return operand_shapes;
166 }
167
168 XlaBuilder::XlaBuilder(const string& computation_name)
169 : name_(computation_name) {}
170
171 XlaBuilder::~XlaBuilder() {}
172
173 XlaOp XlaBuilder::ReportError(const Status& error) {
174 CHECK(!error.ok());
175 if (die_immediately_on_error_) {
176 LOG(FATAL) << "error building computation: " << error;
177 }
178
179 if (first_error_.ok()) {
180 first_error_ = error;
181 first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
182 }
183 return XlaOp(this);
184 }
185
186 XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
187 if (!first_error_.ok()) {
188 return XlaOp(this);
189 }
190 if (!op.ok()) {
191 return ReportError(op.status());
192 }
193 return op.ValueOrDie();
194 }
195
196 XlaOp XlaBuilder::ReportErrorOrReturn(
197 const std::function<StatusOr<XlaOp>()>& op_creator) {
198 return ReportErrorOrReturn(op_creator());
199 }
200
201 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
202 TF_RETURN_IF_ERROR(first_error_);
203 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
204 LookUpInstructionByHandle(root_id));
205
206 ProgramShape program_shape;
207
208 *program_shape.mutable_result() = Shape(root_proto->shape());
209
210 // Check that the parameter numbers are continuous from 0, and add parameter
211 // shapes and names to the program shape.
212 const int64 param_count = parameter_numbers_.size();
213 for (int64 i = 0; i < param_count; i++) {
214 program_shape.add_parameters();
215 program_shape.add_parameter_names();
216 }
217 for (const HloInstructionProto& instr : instructions_) {
218 // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
219 // to verify continuity, we just need to verify that every parameter is in
220 // the right range.
221 if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
222 const int64 index = instr.parameter_number();
223 TF_RET_CHECK(index >= 0 && index < param_count)
224 << "invalid parameter number: " << index;
225 *program_shape.mutable_parameters(index) = Shape(instr.shape());
226 *program_shape.mutable_parameter_names(index) = instr.name();
227 }
228 }
229 return program_shape;
230 }
231
232 StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
233 TF_RET_CHECK(!instructions_.empty());
234 return GetProgramShape(instructions_.back().id());
235 }
236
237 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
238 if (root.builder_ != this) {
239 return InvalidArgument("Given root operation is not in this computation.");
240 }
241 return GetProgramShape(root.handle());
242 }
243
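// Walks the operands of the instruction with the given handle depth-first and
// clears *is_constant when an op that cannot be constant-folded (e.g.
// kParameter, kRng, kInfeed, kWhile) is reached.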
244 void XlaBuilder::IsConstantVisitor(const int64 op_handle,
245 absl::flat_hash_set<int64>* visited,
246 bool* is_constant) const {
247 if (visited->contains(op_handle) || !*is_constant) {
248 return;
249 }
250
251 const HloInstructionProto& instr =
252 *(LookUpInstructionByHandle(op_handle).ValueOrDie());
253 const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
254 switch (opcode) {
255 default:
256 for (const int64 operand_id : instr.operand_ids()) {
257 IsConstantVisitor(operand_id, visited, is_constant);
258 }
259 // TODO(b/32495713): We aren't checking the called computations.
260 break;
261
262 case HloOpcode::kGetDimensionSize:
263 // GetDimensionSize is always considered constant in XLA -- if a dynamic
264 // dimension is present, -1 is returned.
265 break;
266
267 // Non functional ops.
268 case HloOpcode::kRng:
269 case HloOpcode::kAllReduce:
270 // TODO(b/33009255): Implement constant folding for cross replica sum.
271 case HloOpcode::kInfeed:
272 case HloOpcode::kOutfeed:
273 case HloOpcode::kCall:
274 // TODO(b/32495713): We aren't checking the to_apply computation itself,
275 // so we conservatively say that computations containing the Call op
276 // cannot be constant. We cannot set is_functional=false in other similar
277 // cases since we're already relying on IsConstant to return true.
278 case HloOpcode::kCustomCall:
279 case HloOpcode::kWhile:
280 // TODO(b/32495713): We aren't checking the condition and body
281 // computations themselves.
282 case HloOpcode::kScatter:
283 // TODO(b/32495713): We aren't checking the embedded computation in
284 // Scatter.
285 case HloOpcode::kSend:
286 case HloOpcode::kRecv:
287 case HloOpcode::kParameter:
288 *is_constant = false;
289 break;
290 }
291 if (!*is_constant) {
292 VLOG(1) << "Non-constant: " << instr.name();
293 }
294 visited->insert(op_handle);
295 }
296
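// Marks dimension target_dim_num of the parameter numbered target_param_num
// (at tuple index target_param_index) as dynamic, and records that its runtime
// size is carried by the parameter identified by dynamic_size_param_num and
// dynamic_size_param_index.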
297 Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
298 ShapeIndex dynamic_size_param_index,
299 int64 target_param_num,
300 ShapeIndex target_param_index,
301 int64 target_dim_num) {
302 bool param_exists = false;
303 for (size_t index = 0; index < instructions_.size(); ++index) {
304 HloInstructionProto& instr = instructions_[index];
305 if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
306 instr.parameter_number() == target_param_num) {
307 param_exists = true;
308 Shape param_shape(instr.shape());
309 Shape* param_shape_ptr = &param_shape;
310 for (int64 index : target_param_index) {
311 param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
312 }
313 param_shape_ptr->set_dynamic_dimension(target_dim_num,
314 /*is_dynamic=*/true);
315 *instr.mutable_shape() = param_shape.ToProto();
316 instruction_shapes_[index] =
317 absl::make_unique<Shape>(std::move(param_shape));
318 }
319 }
320 if (!param_exists) {
321 return InvalidArgument(
322 "Asked to mark parameter %lld as dynamic sized parameter, but the "
323 "doesn't exists",
324 target_param_num);
325 }
326
327 TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
328 DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
329 dynamic_size_param_index},
330 DynamicParameterBinding::DynamicDimension{
331 target_param_num, target_param_index, target_dim_num}));
332 return Status::OK();
333 }
334
335 Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
336 std::string attribute,
337 std::string value) {
338 TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
339 auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
340 (*frontend_attributes->mutable_map())[attribute] = std::move(value);
341 return Status::OK();
342 }
343
344 XlaComputation XlaBuilder::BuildAndNoteError() {
345 DCHECK(parent_builder_ != nullptr);
346 auto build_status = Build();
347 if (!build_status.ok()) {
348 parent_builder_->ReportError(
349 AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
350 return {};
351 }
352 return build_status.ConsumeValueOrDie();
353 }
354
355 Status XlaBuilder::GetCurrentStatus() const {
356 if (!first_error_.ok()) {
357 string backtrace;
358 first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
359 return AppendStatus(first_error_, backtrace);
360 }
361 return Status::OK();
362 }
363
364 StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
365 TF_RETURN_IF_ERROR(GetCurrentStatus());
366 return Build(instructions_.back().id(), remove_dynamic_dimensions);
367 }
368
369 StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
370 bool remove_dynamic_dimensions) {
371 if (root.builder_ != this) {
372 return InvalidArgument("Given root operation is not in this computation.");
373 }
374 return Build(root.handle(), remove_dynamic_dimensions);
375 }
376
377 StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
378 bool remove_dynamic_dimensions) {
379 TF_RETURN_IF_ERROR(GetCurrentStatus());
380
381 // TODO(b/121223198): The XLA backend cannot handle dynamic dimensions yet, so
382 // remove all dynamic dimensions before building the XLA program, until the
383 // backend supports them.
384 if (remove_dynamic_dimensions) {
385 std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
386 if (shape->tuple_shapes_size() != 0) {
387 for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) {
388 remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
389 }
390 }
391 for (int64 i = 0; i < shape->dimensions_size(); ++i) {
392 shape->set_dynamic_dimension(i, false);
393 }
394 };
395 for (size_t index = 0; index < instructions_.size(); ++index) {
396 remove_dynamic_dimension(instruction_shapes_[index].get());
397 *instructions_[index].mutable_shape() =
398 instruction_shapes_[index]->ToProto();
399 }
400 }
401
402 HloComputationProto entry;
403 SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
404 TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
405 *entry.mutable_program_shape() = program_shape.ToProto();
406 entry.set_root_id(root_id);
407
408 for (auto& instruction : instructions_) {
409 // Ensures that the instruction names are unique among the whole graph.
410 instruction.set_name(
411 GetFullName(instruction.name(), kNameSeparator, instruction.id()));
412 entry.add_instructions()->Swap(&instruction);
413 }
414
415 XlaComputation computation(entry.id());
416 HloModuleProto* module = computation.mutable_proto();
417 module->set_name(entry.name());
418 module->set_id(entry.id());
419 module->set_entry_computation_name(entry.name());
420 module->set_entry_computation_id(entry.id());
421 *module->mutable_host_program_shape() = entry.program_shape();
422 for (auto& e : embedded_) {
423 module->add_computations()->Swap(&e.second);
424 }
425 module->add_computations()->Swap(&entry);
426 if (!input_output_aliases_.empty()) {
427 TF_RETURN_IF_ERROR(
428 PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
429 }
430 *(module->mutable_dynamic_parameter_binding()) =
431 dynamic_parameter_binding_.ToProto();
432
433 // Clear data held by this builder.
434 this->instructions_.clear();
435 this->instruction_shapes_.clear();
436 this->handle_to_index_.clear();
437 this->embedded_.clear();
438 this->parameter_numbers_.clear();
439
440 return std::move(computation);
441 }
442
443 /* static */ Status XlaBuilder::PopulateInputOutputAlias(
444 HloModuleProto* module, const ProgramShape& program_shape,
445 const std::vector<InputOutputAlias>& input_output_aliases) {
446 HloInputOutputAliasConfig config(program_shape.result());
447 for (auto& alias : input_output_aliases) {
448 // The HloInputOutputAliasConfig does not do parameter validation as it only
449 // carries the result shape. Maybe it should be constructed with a
450 // ProgramShape to allow full validation. We will still get an error when
451 // trying to compile the HLO module, but it would be better to have validation
452 // at this stage.
453 if (alias.param_number >= program_shape.parameters_size()) {
454 return InvalidArgument("Invalid parameter number %ld (total %ld)",
455 alias.param_number,
456 program_shape.parameters_size());
457 }
458 const Shape& parameter_shape = program_shape.parameters(alias.param_number);
459 if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
460 return InvalidArgument("Invalid parameter %ld index: %s",
461 alias.param_number,
462 alias.param_index.ToString().c_str());
463 }
464 TF_RETURN_IF_ERROR(config.SetUpAlias(
465 alias.output_index, alias.param_number, alias.param_index,
466 HloInputOutputAliasConfig::AliasKind::kUserAlias));
467 }
468 *module->mutable_input_output_alias() = config.ToProto();
469 return Status::OK();
470 }
471
472 StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
473 const Shape& shape, XlaOp operand,
474 absl::Span<const int64> broadcast_dimensions) {
475 TF_RETURN_IF_ERROR(first_error_);
476
477 HloInstructionProto instr;
478 *instr.mutable_shape() = shape.ToProto();
479 for (int64 dim : broadcast_dimensions) {
480 instr.add_dimensions(dim);
481 }
482
483 return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
484 }
485
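// For example, broadcasting an f32[1,4] operand to an f32[3,4] output first
// reshapes the operand to f32[4] (dropping the size-1 dimension) and then
// broadcasts it into dimension 1 of the output.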
486 StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
487 XlaOp operand) {
488 TF_RETURN_IF_ERROR(first_error_);
489
490 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
491
492 CHECK(ShapeUtil::IsScalar(*operand_shape) ||
493 operand_shape->rank() == output_shape.rank());
494 Shape broadcast_shape =
495 ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());
496
497 // Do explicit broadcast for scalar.
498 if (ShapeUtil::IsScalar(*operand_shape)) {
499 return InDimBroadcast(broadcast_shape, operand, {});
500 }
501
502 // Do explicit broadcast for degenerate broadcast.
503 std::vector<int64> broadcast_dimensions;
504 std::vector<int64> reshaped_dimensions;
505 for (int i = 0; i < operand_shape->rank(); i++) {
506 if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
507 broadcast_dimensions.push_back(i);
508 reshaped_dimensions.push_back(operand_shape->dimensions(i));
509 } else {
510 TF_RET_CHECK(operand_shape->dimensions(i) == 1)
511 << "An explicit broadcast sequence requires the broadcasted "
512 "dimensions to be trivial; operand shape: "
513 << *operand_shape << "; output_shape: " << output_shape;
514 }
515 }
516
517 Shape reshaped_shape =
518 ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);
519
520 std::vector<std::pair<int64, int64>> unmodified_dims =
521 ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);
522
523 for (auto& unmodified : unmodified_dims) {
524 if (operand_shape->is_dynamic_dimension(unmodified.first)) {
525 reshaped_shape.set_dynamic_dimension(unmodified.second, true);
526 }
527 }
528
529 // Eliminate the size one dimensions.
530 TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
531 // Broadcast 'reshape' up to the larger size.
532 return InDimBroadcast(broadcast_shape, reshaped_operand,
533 broadcast_dimensions);
534 }
535
536 XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
537 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
538 HloInstructionProto instr;
539 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
540 TF_ASSIGN_OR_RETURN(
541 Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
542 *instr.mutable_shape() = shape.ToProto();
543 return AddInstruction(std::move(instr), unop, {operand});
544 });
545 }
546
547 XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
548 absl::Span<const int64> broadcast_dimensions,
549 absl::optional<ComparisonDirection> direction) {
550 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
551 HloInstructionProto instr;
552 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
553 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
554 TF_ASSIGN_OR_RETURN(
555 Shape shape, ShapeInference::InferBinaryOpShape(
556 binop, *lhs_shape, *rhs_shape, broadcast_dimensions));
557 *instr.mutable_shape() = shape.ToProto();
558 if (binop == HloOpcode::kCompare) {
559 if (!direction.has_value()) {
560 return InvalidArgument(
561 "kCompare expects a ComparisonDirection, but none provided.");
562 }
563 instr.set_comparison_direction(ComparisonDirectionToString(*direction));
564 } else if (direction.has_value()) {
565 return InvalidArgument(
566 "A comparison direction is provided for a non-compare opcode: %s.",
567 HloOpcodeString(binop));
568 }
569
570 const int64 lhs_rank = lhs_shape->rank();
571 const int64 rhs_rank = rhs_shape->rank();
572
573 XlaOp updated_lhs = lhs;
574 XlaOp updated_rhs = rhs;
575
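// When broadcast_dimensions are given and the ranks differ, first map the
// lower-rank operand into the output rank with an in-dim broadcast; any
// remaining degenerate (size-1) dimensions are expanded by the
// AddBroadcastSequence calls below.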
576 if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
577 const bool should_broadcast_lhs = lhs_rank < rhs_rank;
578 XlaOp from = should_broadcast_lhs ? lhs : rhs;
579 const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;
580
581 std::vector<int64> to_size;
582 std::vector<bool> to_size_is_dynamic;
583 for (int i = 0; i < shape.rank(); i++) {
584 to_size.push_back(shape.dimensions(i));
585 to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
586 }
587 for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
588 int64 to_dim = broadcast_dimensions[from_dim];
589 to_size[to_dim] = from_shape.dimensions(from_dim);
590 to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
591 }
592
593 const Shape& broadcasted_shape = ShapeUtil::MakeShape(
594 from_shape.element_type(), to_size, to_size_is_dynamic);
595 TF_ASSIGN_OR_RETURN(
596 XlaOp broadcasted_operand,
597 InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));
598
599 updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
600 updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
601 }
602
603 TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
604 GetShapePtr(updated_lhs));
605 if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
606 TF_ASSIGN_OR_RETURN(updated_lhs,
607 AddBroadcastSequence(shape, updated_lhs));
608 }
609 TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
610 GetShapePtr(updated_rhs));
611 if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
612 TF_ASSIGN_OR_RETURN(updated_rhs,
613 AddBroadcastSequence(shape, updated_rhs));
614 }
615
616 return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
617 });
618 }
619
620 XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
621 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
622 HloInstructionProto instr;
623 XlaOp updated_lhs = lhs;
624 XlaOp updated_rhs = rhs;
625 XlaOp updated_ehs = ehs;
626 // The client API supports implicit broadcast for kSelect and kClamp, but
627 // XLA does not support implicit broadcast. Make implicit broadcast explicit
628 // and update the operands.
629 if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
630 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
631 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
632 TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));
633
634 absl::optional<Shape> non_scalar_shape;
635 for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
636 if (shape->IsArray() && shape->rank() != 0) {
637 if (non_scalar_shape.has_value()) {
638 // TODO(jpienaar): The case where we need to compute the broadcasted
639 // shape by considering multiple of the shapes is not implemented.
640 // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
641 TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
642 shape->dimensions())
643 << "Unimplemented implicit broadcast.";
644 } else {
645 non_scalar_shape = *shape;
646 }
647 }
648 }
649 if (non_scalar_shape.has_value()) {
650 if (ShapeUtil::IsScalar(*lhs_shape)) {
651 TF_ASSIGN_OR_RETURN(updated_lhs,
652 AddBroadcastSequence(*non_scalar_shape, lhs));
653 }
654 if (ShapeUtil::IsScalar(*rhs_shape)) {
655 TF_ASSIGN_OR_RETURN(updated_rhs,
656 AddBroadcastSequence(*non_scalar_shape, rhs));
657 }
658 if (ShapeUtil::IsScalar(*ehs_shape)) {
659 TF_ASSIGN_OR_RETURN(updated_ehs,
660 AddBroadcastSequence(*non_scalar_shape, ehs));
661 }
662 }
663 }
664
665 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
666 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
667 TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
668 StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
669 triop, *lhs_shape, *rhs_shape, *ehs_shape);
670 if (!status_or_shape.status().ok()) {
671 return InvalidArgument(
672 "%s Input scalar shapes may have been changed to non-scalar shapes.",
673 status_or_shape.status().error_message());
674 }
675 *instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto();
676 return AddInstruction(std::move(instr), triop,
677 {updated_lhs, updated_rhs, updated_ehs});
678 });
679 }
680
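// A non-scalar literal whose elements all equal its first element is emitted
// as a scalar constant followed by a Broadcast to the literal's shape, rather
// than as a full dense constant.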
681 XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
682 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
683 if (literal.shape().IsArray() && literal.element_count() > 1 &&
684 literal.IsAllFirst()) {
685 Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal);
686 HloInstructionProto instr;
687 *instr.mutable_shape() = scalar.shape().ToProto();
688 *instr.mutable_literal() = scalar.ToProto();
689 TF_ASSIGN_OR_RETURN(
690 XlaOp scalar_op,
691 AddInstruction(std::move(instr), HloOpcode::kConstant));
692 return Broadcast(scalar_op, literal.shape().dimensions());
693 } else {
694 HloInstructionProto instr;
695 *instr.mutable_shape() = literal.shape().ToProto();
696 *instr.mutable_literal() = literal.ToProto();
697 return AddInstruction(std::move(instr), HloOpcode::kConstant);
698 }
699 });
700 }
701
702 XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
703 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
704 HloInstructionProto instr;
705 *instr.mutable_shape() = shape.ToProto();
706 instr.add_dimensions(iota_dimension);
707 return AddInstruction(std::move(instr), HloOpcode::kIota);
708 });
709 }
710
711 XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
712 return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
713 }
714
715 XlaOp XlaBuilder::Call(const XlaComputation& computation,
716 absl::Span<const XlaOp> operands) {
717 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
718 HloInstructionProto instr;
719 std::vector<const Shape*> operand_shape_ptrs;
720 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
721 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
722 [](const Shape& shape) { return &shape; });
723 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
724 computation.GetProgramShape());
725 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
726 operand_shape_ptrs,
727 /*to_apply=*/called_program_shape));
728 *instr.mutable_shape() = shape.ToProto();
729
730 AddCalledComputation(computation, &instr);
731
732 return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
733 });
734 }
735
736 XlaOp XlaBuilder::Parameter(
737 int64 parameter_number, const Shape& shape, const string& name,
738 const std::vector<bool>& replicated_at_leaf_buffers) {
739 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
740 HloInstructionProto instr;
741 if (!parameter_numbers_.insert(parameter_number).second) {
742 return InvalidArgument("parameter %d already registered",
743 parameter_number);
744 }
745 instr.set_parameter_number(parameter_number);
746 instr.set_name(name);
747 *instr.mutable_shape() = shape.ToProto();
748 if (!replicated_at_leaf_buffers.empty()) {
749 auto replication = instr.mutable_parameter_replication();
750 for (bool replicated : replicated_at_leaf_buffers) {
751 replication->add_replicated_at_leaf_buffers(replicated);
752 }
753 }
754 return AddInstruction(std::move(instr), HloOpcode::kParameter);
755 });
756 }
757
758 XlaOp XlaBuilder::Broadcast(XlaOp operand,
759 absl::Span<const int64> broadcast_sizes) {
760 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
761 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
762 TF_ASSIGN_OR_RETURN(
763 const Shape& shape,
764 ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));
765
766 // The client-level broadcast op just appends dimensions on the left (adds
767 // lowest numbered dimensions). The HLO broadcast instruction is more
768 // flexible and can add new dimensions anywhere. The instruction's
769 // dimensions field maps operand dimensions to dimensions in the broadcast
770 // output, so to append dimensions on the left the instruction's dimensions
771 // should just be the n highest dimension numbers of the output shape where
772 // n is the number of input dimensions.
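// For example, broadcasting an f32[2,3] operand with broadcast_sizes={4}
// produces an f32[4,2,3] output and dimensions={1,2}.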
773 const int64 operand_rank = operand_shape->rank();
774 std::vector<int64> dimensions(operand_rank);
775 for (int i = 0; i < operand_rank; ++i) {
776 dimensions[i] = i + shape.rank() - operand_rank;
777 }
778 return InDimBroadcast(shape, operand, dimensions);
779 });
780 }
781
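// For example, broadcasting an f32[1,3] operand to out_dim_size={2,3} with
// broadcast_dimensions={0,1} performs an in-dim broadcast to f32[1,3] followed
// by a degenerate-dimension broadcast to f32[2,3].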
782 XlaOp XlaBuilder::BroadcastInDim(
783 XlaOp operand, const absl::Span<const int64> out_dim_size,
784 const absl::Span<const int64> broadcast_dimensions) {
785 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
786 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
787 // Output shape: in the case of a degenerate broadcast, out_dim_size is
788 // not necessarily the same as the dimension sizes of the output shape.
789 auto output_shape =
790 ShapeUtil::MakeShape(operand_shape->element_type(), out_dim_size);
791 if (operand_shape->rank() != broadcast_dimensions.size()) {
792 return InvalidArgument(
793 "Size of broadcast_dimensions has to match operand's rank; operand "
794 "rank: %lld, size of broadcast_dimensions %u.",
795 operand_shape->rank(), broadcast_dimensions.size());
796 }
797 for (int i = 0; i < broadcast_dimensions.size(); i++) {
798 if (broadcast_dimensions[i] < 0 ||
799 broadcast_dimensions[i] >= out_dim_size.size()) {
800 return InvalidArgument("Broadcast dimension %lld is out of bounds",
801 broadcast_dimensions[i]);
802 }
803 output_shape.set_dynamic_dimension(
804 broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
805 }
806
807 TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
808 *operand_shape, output_shape, broadcast_dimensions)
809 .status());
810 std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
811 for (int i = 0; i < broadcast_dimensions.size(); i++) {
812 in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
813 }
814 const auto& in_dim_shape =
815 ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
816 TF_ASSIGN_OR_RETURN(
817 XlaOp in_dim_broadcast,
818 InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));
819
820 // If broadcast is not degenerate, return broadcasted result.
821 if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
822 return in_dim_broadcast;
823 }
824
825 // Otherwise handle degenerate broadcast case.
826 return AddBroadcastSequence(output_shape, in_dim_broadcast);
827 });
828 }
829
830 StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
831 int64 inferred_dimension) {
832 TF_RETURN_IF_ERROR(first_error_);
833
834 HloInstructionProto instr;
835 *instr.mutable_shape() = shape.ToProto();
836 if (inferred_dimension != -1) {
837 instr.add_dimensions(inferred_dimension);
838 }
839 return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
840 }
841
842 XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
843 absl::Span<const int64> limit_indices,
844 absl::Span<const int64> strides) {
845 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
846 HloInstructionProto instr;
847 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
848 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
849 *operand_shape, start_indices,
850 limit_indices, strides));
851 *instr.mutable_shape() = shape.ToProto();
852 for (int i = 0; i < start_indices.size(); i++) {
853 auto* slice_config = instr.add_slice_dimensions();
854 slice_config->set_start(start_indices[i]);
855 slice_config->set_limit(limit_indices[i]);
856 slice_config->set_stride(strides[i]);
857 }
858
859 return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
860 });
861 }
862
863 XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
864 int64 limit_index, int64 stride, int64 dimno) {
865 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
866 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
867 std::vector<int64> starts(shape->rank(), 0);
868 std::vector<int64> limits(shape->dimensions().begin(),
869 shape->dimensions().end());
870 std::vector<int64> strides(shape->rank(), 1);
871 starts[dimno] = start_index;
872 limits[dimno] = limit_index;
873 strides[dimno] = stride;
874 return Slice(operand, starts, limits, strides);
875 });
876 }
877
878 XlaOp XlaBuilder::DynamicSlice(XlaOp operand, XlaOp start_indices,
879 absl::Span<const int64> slice_sizes) {
880 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
881 HloInstructionProto instr;
882
883 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
884 TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
885 GetShapePtr(start_indices));
886 TF_ASSIGN_OR_RETURN(
887 Shape shape, ShapeInference::InferDynamicSliceShape(
888 *operand_shape, {*start_indices_shape}, slice_sizes));
889 *instr.mutable_shape() = shape.ToProto();
890
891 for (int64 size : slice_sizes) {
892 instr.add_dynamic_slice_sizes(size);
893 }
894
895 return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice,
896 {operand, start_indices});
897 });
898 }
899
900 XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
901 absl::Span<const XlaOp> start_indices,
902 absl::Span<const int64> slice_sizes) {
903 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
904 HloInstructionProto instr;
905
906 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
907 std::vector<const Shape*> start_indices_shape_ptrs;
908 TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
909 GetOperandShapes(start_indices));
910 absl::c_transform(start_indices_shapes,
911 std::back_inserter(start_indices_shape_ptrs),
912 [](const Shape& shape) { return &shape; });
913 TF_ASSIGN_OR_RETURN(Shape shape,
914 ShapeInference::InferDynamicSliceShape(
915 *operand_shape, start_indices_shapes, slice_sizes));
916 *instr.mutable_shape() = shape.ToProto();
917
918 for (int64 size : slice_sizes) {
919 instr.add_dynamic_slice_sizes(size);
920 }
921
922 std::vector<XlaOp> operands = {operand};
923 operands.insert(operands.end(), start_indices.begin(), start_indices.end());
924 return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
925 });
926 }
927
928 XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
929 XlaOp start_indices) {
930 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
931 HloInstructionProto instr;
932
933 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
934 TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
935 TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
936 GetShapePtr(start_indices));
937 TF_ASSIGN_OR_RETURN(
938 Shape shape,
939 ShapeInference::InferDynamicUpdateSliceShape(
940 *operand_shape, *update_shape, {*start_indices_shape}));
941 *instr.mutable_shape() = shape.ToProto();
942
943 return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
944 {operand, update, start_indices});
945 });
946 }
947
948 XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
949 absl::Span<const XlaOp> start_indices) {
950 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
951 HloInstructionProto instr;
952
953 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
954 TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
955 std::vector<const Shape*> start_indices_shape_ptrs;
956 TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
957 GetOperandShapes(start_indices));
958 absl::c_transform(start_indices_shapes,
959 std::back_inserter(start_indices_shape_ptrs),
960 [](const Shape& shape) { return &shape; });
961 TF_ASSIGN_OR_RETURN(
962 Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
963 *operand_shape, *update_shape, start_indices_shapes));
964 *instr.mutable_shape() = shape.ToProto();
965
966 std::vector<XlaOp> operands = {operand, update};
967 operands.insert(operands.end(), start_indices.begin(), start_indices.end());
968 return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
969 operands);
970 });
971 }
972
973 XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
974 int64 dimension) {
975 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
976 HloInstructionProto instr;
977
978 std::vector<const Shape*> operand_shape_ptrs;
979 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
980 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
981 [](const Shape& shape) { return &shape; });
982 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
983 operand_shape_ptrs, dimension));
984 *instr.mutable_shape() = shape.ToProto();
985
986 instr.add_dimensions(dimension);
987
988 return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
989 });
990 }
991
992 XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
993 const PaddingConfig& padding_config) {
994 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
995 HloInstructionProto instr;
996
997 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
998 TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
999 GetShapePtr(padding_value));
1000 TF_ASSIGN_OR_RETURN(
1001 Shape shape, ShapeInference::InferPadShape(
1002 *operand_shape, *padding_value_shape, padding_config));
1003 *instr.mutable_shape() = shape.ToProto();
1004 *instr.mutable_padding_config() = padding_config;
1005
1006 return AddInstruction(std::move(instr), HloOpcode::kPad,
1007 {operand, padding_value});
1008 });
1009 }
1010
1011 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
1012 absl::Span<const int64> new_sizes,
1013 int64 inferred_dimension) {
1014 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1015 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1016 TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape(
1017 *operand_shape, dimensions,
1018 new_sizes, inferred_dimension));
1019 XlaOp transposed = IsIdentityPermutation(dimensions)
1020 ? operand
1021 : Transpose(operand, dimensions);
1022 return Reshape(shape, transposed, inferred_dimension);
1023 });
1024 }
1025
1026 XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
1027 int64 inferred_dimension) {
1028 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1029 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
1030 std::vector<int64> dimensions(shape->dimensions_size());
1031 std::iota(dimensions.begin(), dimensions.end(), 0);
1032 return Reshape(operand, dimensions, new_sizes, inferred_dimension);
1033 });
1034 }
1035
1036 XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
1037 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1038 if (dimensions.size() <= 1) {
1039 // Not collapsing anything; trivially return the operand rather than
1040 // enqueueing a no-op reshape.
1041 return operand;
1042 }
1043
1044 // Out-of-order collapse is not supported.
1045 // Checks that the collapsed dimensions are in order and consecutive.
1046 for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
1047 if (dimensions[i] - 1 != dimensions[i - 1]) {
1048 return InvalidArgument(
1049 "Collapsed dimensions are not in consecutive order.");
1050 }
1051 }
1052
1053 // Create a new sizes vector from the old shape, replacing the collapsed
1054 // dimensions by the product of their sizes.
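// For example, collapsing dimensions {1,2} of an f32[2,3,4,5] operand yields
// new sizes [2,12,5].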
1055 TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));
1056
1057 VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
1058 VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
1059
1060 std::vector<int64> new_sizes;
1061 for (int i = 0; i < original_shape->rank(); ++i) {
1062 if (i <= dimensions.front() || i > dimensions.back()) {
1063 new_sizes.push_back(original_shape->dimensions(i));
1064 } else {
1065 new_sizes.back() *= original_shape->dimensions(i);
1066 }
1067 }
1068
1069 VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";
1070
1071 return Reshape(operand, new_sizes);
1072 });
1073 }
1074
1075 void XlaBuilder::Trace(const string& tag, XlaOp operand) {
1076 ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1077 HloInstructionProto instr;
1078 *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
1079 *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
1080 return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
1081 });
1082 }
1083
1084 XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
1085 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1086 TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
1087 TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
1088 TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
1089 HloOpcode opcode =
1090 true_shape->IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
1091 return TernaryOp(opcode, pred, on_true, on_false);
1092 });
1093 }
1094
1095 XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
1096 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1097 HloInstructionProto instr;
1098 std::vector<const Shape*> operand_shape_ptrs;
1099 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
1100 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1101 [](const Shape& shape) { return &shape; });
1102 TF_ASSIGN_OR_RETURN(const Shape shape,
1103 ShapeInference::InferVariadicOpShape(
1104 HloOpcode::kTuple, operand_shape_ptrs));
1105 *instr.mutable_shape() = shape.ToProto();
1106 return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
1107 });
1108 }
1109
1110 XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
1111 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1112 HloInstructionProto instr;
1113 TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
1114 if (!tuple_shape->IsTuple()) {
1115 return InvalidArgument(
1116 "Operand to GetTupleElement() is not a tuple; got %s",
1117 ShapeUtil::HumanString(*tuple_shape));
1118 }
1119 if (index < 0 || index >= ShapeUtil::TupleElementCount(*tuple_shape)) {
1120 return InvalidArgument(
1121 "GetTupleElement() index (%d) out of range for tuple shape %s", index,
1122 ShapeUtil::HumanString(*tuple_shape));
1123 }
1124 *instr.mutable_shape() =
1125 ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto();
1126
1127 instr.set_tuple_index(index);
1128
1129 return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
1130 {tuple_data});
1131 });
1132 }
1133
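// The default Dot contracts the last dimension of lhs (dimension 1 for a
// matrix, dimension 0 for a vector) with dimension 0 of rhs.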
1134 XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
1135 const PrecisionConfig* precision_config) {
1136 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1137 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1138
1139 DotDimensionNumbers dimension_numbers;
1140 dimension_numbers.add_lhs_contracting_dimensions(
1141 lhs_shape->dimensions_size() == 1 ? 0 : 1);
1142 dimension_numbers.add_rhs_contracting_dimensions(0);
1143 return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
1144 });
1145 }
1146
1147 XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs,
1148 const DotDimensionNumbers& dimension_numbers,
1149 const PrecisionConfig* precision_config) {
1150 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1151 HloInstructionProto instr;
1152 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1153 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1154 TF_ASSIGN_OR_RETURN(Shape shape,
1155 ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape,
1156 dimension_numbers));
1157 *instr.mutable_shape() = shape.ToProto();
1158 *instr.mutable_dot_dimension_numbers() = dimension_numbers;
1159 if (precision_config != nullptr) {
1160 *instr.mutable_precision_config() = *precision_config;
1161 }
1162 return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
1163 });
1164 }
1165
1166 Status XlaBuilder::VerifyConvolution(
1167 const Shape& lhs_shape, const Shape& rhs_shape,
1168 const ConvolutionDimensionNumbers& dimension_numbers) const {
1169 if (lhs_shape.rank() != rhs_shape.rank()) {
1170 return InvalidArgument(
1171 "Convolution arguments must have same number of "
1172 "dimensions. Got: %s and %s",
1173 ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1174 }
1175 int num_dims = lhs_shape.rank();
1176 if (num_dims < 2) {
1177 return InvalidArgument(
1178 "Convolution expects argument arrays with >= 3 dimensions. "
1179 "Got: %s and %s",
1180 ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1181 }
1182 int num_spatial_dims = num_dims - 2;
1183
1184 const auto check_spatial_dimensions =
1185 [&](const char* const field_name,
1186 const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
1187 numbers) {
1188 if (numbers.size() != num_spatial_dims) {
1189 return InvalidArgument("Expected %d elements for %s, but got %d.",
1190 num_spatial_dims, field_name, numbers.size());
1191 }
1192 for (int i = 0; i < numbers.size(); ++i) {
1193 if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
1194 return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
1195 field_name, i, numbers.Get(i));
1196 }
1197 }
1198 return Status::OK();
1199 };
1200 TF_RETURN_IF_ERROR(
1201 check_spatial_dimensions("input_spatial_dimensions",
1202 dimension_numbers.input_spatial_dimensions()));
1203 TF_RETURN_IF_ERROR(
1204 check_spatial_dimensions("kernel_spatial_dimensions",
1205 dimension_numbers.kernel_spatial_dimensions()));
1206 return check_spatial_dimensions(
1207 "output_spatial_dimensions",
1208 dimension_numbers.output_spatial_dimensions());
1209 }
1210
1211 XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
1212 absl::Span<const int64> window_strides, Padding padding,
1213 int64 feature_group_count, int64 batch_group_count,
1214 const PrecisionConfig* precision_config) {
1215 return ConvWithGeneralDimensions(
1216 lhs, rhs, window_strides, padding,
1217 CreateDefaultConvDimensionNumbers(window_strides.size()),
1218 feature_group_count, batch_group_count, precision_config);
1219 }
1220
1221 XlaOp XlaBuilder::ConvWithGeneralPadding(
1222 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1223 absl::Span<const std::pair<int64, int64>> padding,
1224 int64 feature_group_count, int64 batch_group_count,
1225 const PrecisionConfig* precision_config) {
1226 return ConvGeneral(lhs, rhs, window_strides, padding,
1227 CreateDefaultConvDimensionNumbers(window_strides.size()),
1228 feature_group_count, batch_group_count, precision_config);
1229 }
1230
1231 XlaOp XlaBuilder::ConvWithGeneralDimensions(
1232 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1233 Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1234 int64 feature_group_count, int64 batch_group_count,
1235 const PrecisionConfig* precision_config) {
1236 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1237 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1238 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1239
1240 TF_RETURN_IF_ERROR(
1241 VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1242
1243 std::vector<int64> base_area_dimensions(
1244 dimension_numbers.input_spatial_dimensions_size());
1245 for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
1246 ++i) {
1247 base_area_dimensions[i] =
1248 lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
1249 }
1250
1251 std::vector<int64> window_dimensions(
1252 dimension_numbers.kernel_spatial_dimensions_size());
1253 for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
1254 ++i) {
1255 window_dimensions[i] =
1256 rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1257 }
1258
1259 return ConvGeneral(lhs, rhs, window_strides,
1260 MakePadding(base_area_dimensions, window_dimensions,
1261 window_strides, padding),
1262 dimension_numbers, feature_group_count,
1263 batch_group_count, precision_config);
1264 });
1265 }
1266
1267 XlaOp XlaBuilder::ConvGeneral(
1268 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1269 absl::Span<const std::pair<int64, int64>> padding,
1270 const ConvolutionDimensionNumbers& dimension_numbers,
1271 int64 feature_group_count, int64 batch_group_count,
1272 const PrecisionConfig* precision_config) {
1273 return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
1274 dimension_numbers, feature_group_count,
1275 batch_group_count, precision_config);
1276 }
1277
1278 XlaOp XlaBuilder::ConvGeneralDilated(
1279 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1280 absl::Span<const std::pair<int64, int64>> padding,
1281 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1282 const ConvolutionDimensionNumbers& dimension_numbers,
1283 int64 feature_group_count, int64 batch_group_count,
1284 const PrecisionConfig* precision_config) {
1285 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1286 HloInstructionProto instr;
1287 TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
1288 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
1289 TF_RETURN_IF_ERROR(
1290 VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
1291
1292 std::vector<int64> window_dimensions(
1293 dimension_numbers.kernel_spatial_dimensions_size());
1294 for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
1295 ++i) {
1296 window_dimensions[i] =
1297 rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
1298 }
1299 TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
1300 ShapeInference::InferWindowFromDimensions(
1301 window_dimensions, window_strides, padding,
1302 lhs_dilation, rhs_dilation));
1303
1304 TF_ASSIGN_OR_RETURN(
1305 Shape shape, ShapeInference::InferConvolveShape(
1306 *lhs_shape, *rhs_shape, feature_group_count,
1307 batch_group_count, instr.window(), dimension_numbers));
1308 *instr.mutable_shape() = shape.ToProto();
1309
1310 *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
1311 instr.set_feature_group_count(feature_group_count);
1312 instr.set_batch_group_count(batch_group_count);
1313
1314 if (precision_config != nullptr) {
1315 *instr.mutable_precision_config() = *precision_config;
1316 }
1317
1318 return AddInstruction(std::move(instr), HloOpcode::kConvolution,
1319 {lhs, rhs});
1320 });
1321 }
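
// Hedged usage sketch (commented out; not part of this translation unit): the
// Conv* entry points above all funnel into ConvGeneralDilated.
// ConvWithGeneralDimensions derives explicit per-dimension padding with
// MakePadding, and ConvGeneral passes empty dilation spans, which the window
// inference treats as dilation factors of 1. The concrete sizes below (28x28
// spatial input, 3x3 kernel, stride 1) and the name `dnums` (a client-built
// ConvolutionDimensionNumbers) are hypothetical.
//
//   std::vector<std::pair<int64, int64>> pads =
//       MakePadding(/*input_dimensions=*/{28, 28},
//                   /*window_dimensions=*/{3, 3},
//                   /*window_strides=*/{1, 1}, Padding::kSame);
//   // pads == {{1, 1}, {1, 1}}. ConvGeneral(lhs, rhs, {1, 1}, pads, dnums) is
//   // then equivalent to ConvGeneralDilated(lhs, rhs, {1, 1}, pads,
//   // /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, dnums).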
1322
1323 XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
1324 const absl::Span<const int64> fft_length) {
1325 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1326 HloInstructionProto instr;
1327 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1328 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
1329 *operand_shape, fft_type, fft_length));
1330 *instr.mutable_shape() = shape.ToProto();
1331 instr.set_fft_type(fft_type);
1332 for (int64 i : fft_length) {
1333 instr.add_fft_length(i);
1334 }
1335
1336 return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
1337 });
1338 }
1339
1340 XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
1341 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1342 HloInstructionProto instr;
1343 if (!LayoutUtil::HasLayout(shape)) {
1344 return InvalidArgument("Given shape to Infeed must have a layout");
1345 }
1346 const Shape infeed_instruction_shape =
1347 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1348 *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1349 instr.set_infeed_config(config);
1350
1351 if (shape.IsArray() && sharding() &&
1352 sharding()->type() == OpSharding::OTHER) {
1353 // TODO(b/110793772): Support tiled array-shaped infeeds.
1354 return InvalidArgument(
1355 "Tiled sharding is not yet supported for array-shaped infeeds");
1356 }
1357
1358 if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1359 return InvalidArgument(
1360 "Replicated sharding is not yet supported for infeeds");
1361 }
1362
1363 // Infeed takes a single token operand. Generate the token to pass to the
1364 // infeed.
1365 XlaOp token;
1366 auto make_token = [&]() {
1367 HloInstructionProto token_instr;
1368 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1369 return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
1370 };
1371 if (sharding()) {
1372 // Arbitrarily assign token to device 0.
1373 OpSharding sharding = sharding_builder::AssignDevice(0);
1374 XlaScopedShardingAssignment scoped_sharding(this, sharding);
1375 TF_ASSIGN_OR_RETURN(token, make_token());
1376 } else {
1377 TF_ASSIGN_OR_RETURN(token, make_token());
1378 }
1379
1380 // The sharding is set by the client according to the data tuple shape.
1381 // However, the shape of the infeed instruction is a tuple containing the
1382 // data and a token. For tuple sharding type, the sharding must be changed
1383 // to accommodate the token.
1384 XlaOp infeed;
1385 if (sharding() && sharding()->type() == OpSharding::TUPLE) {
1386 // TODO(b/80000000): Remove this when clients have been updated to handle
1387 // tokens.
1388 OpSharding infeed_instruction_sharding = *sharding();
1389 // Arbitrarily assign the token to device 0.
1390 *infeed_instruction_sharding.add_tuple_shardings() =
1391 sharding_builder::AssignDevice(0);
1392 XlaScopedShardingAssignment scoped_sharding(this,
1393 infeed_instruction_sharding);
1394 TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1395 HloOpcode::kInfeed, {token}));
1396 } else {
1397 TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1398 HloOpcode::kInfeed, {token}));
1399 }
1400
1401     // The infeed instruction produces a tuple of the infeed data and a token.
1402     // Return the XLA op containing the data.
1403 // TODO(b/80000000): Remove this when clients have been updated to handle
1404 // tokens.
1405 HloInstructionProto infeed_data;
1406 *infeed_data.mutable_shape() = shape.ToProto();
1407 infeed_data.set_tuple_index(0);
1408 return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
1409 {infeed});
1410 });
1411 }
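
// Hedged usage sketch (commented out): Infeed hides the token plumbing above
// by creating its own kAfterAll token and returning get-tuple-element 0 of the
// (data, token) tuple, while InfeedWithToken leaves token threading to the
// caller. The builder name and shape below are hypothetical.
//
//   XlaBuilder b("infeed_example");
//   Shape s = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
//   XlaOp data = Infeed(&b, s);            // Data only; the token is hidden.
//   XlaOp tok = CreateToken(&b);
//   XlaOp both = InfeedWithToken(tok, s);  // Tuple of (data, token).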
1412
1413 XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
1414 const string& config) {
1415 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1416 HloInstructionProto instr;
1417 if (!LayoutUtil::HasLayout(shape)) {
1418 return InvalidArgument("Given shape to Infeed must have a layout");
1419 }
1420 const Shape infeed_instruction_shape =
1421 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1422 *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1423 instr.set_infeed_config(config);
1424
1425 if (shape.IsArray() && sharding() &&
1426 sharding()->type() == OpSharding::OTHER) {
1427 // TODO(b/110793772): Support tiled array-shaped infeeds.
1428 return InvalidArgument(
1429 "Tiled sharding is not yet supported for array-shaped infeeds");
1430 }
1431
1432 if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
1433 return InvalidArgument(
1434 "Replicated sharding is not yet supported for infeeds");
1435 }
1436
1437 return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
1438 });
1439 }
1440
1441 void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
1442 const string& outfeed_config) {
1443 ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1444 HloInstructionProto instr;
1445
1446 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1447
1448 // Check and set outfeed shape.
1449 if (!LayoutUtil::HasLayout(shape_with_layout)) {
1450 return InvalidArgument("Given shape to Outfeed must have a layout");
1451 }
1452 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1453 if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1454 return InvalidArgument(
1455 "Outfeed shape %s must be compatible with operand shape %s",
1456 ShapeUtil::HumanStringWithLayout(shape_with_layout),
1457 ShapeUtil::HumanStringWithLayout(*operand_shape));
1458 }
1459 *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1460
1461 instr.set_outfeed_config(outfeed_config);
1462
1463 // Outfeed takes a token as its second operand. Generate the token to pass
1464 // to the outfeed.
1465 HloInstructionProto token_instr;
1466 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1467 TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
1468 HloOpcode::kAfterAll, {}));
1469
1470 TF_RETURN_IF_ERROR(
1471 AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
1472 .status());
1473
1474 // The outfeed instruction produces a token. However, existing users expect
1475 // a nil shape (empty tuple). This should only be relevant if the outfeed is
1476 // the root of a computation.
1477 // TODO(b/80000000): Remove this when clients have been updated to handle
1478 // tokens.
1479 HloInstructionProto tuple_instr;
1480 *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
1481
1482 // The dummy tuple should have no sharding.
1483 {
1484 XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
1485 TF_ASSIGN_OR_RETURN(
1486 XlaOp empty_tuple,
1487 AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
1488 return empty_tuple;
1489 }
1490 });
1491 }
1492
1493 XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
1494 const Shape& shape_with_layout,
1495 const string& outfeed_config) {
1496 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1497 HloInstructionProto instr;
1498
1499 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1500
1501 // Check and set outfeed shape.
1502 if (!LayoutUtil::HasLayout(shape_with_layout)) {
1503 return InvalidArgument("Given shape to Outfeed must have a layout");
1504 }
1505 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1506 if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
1507 return InvalidArgument(
1508 "Outfeed shape %s must be compatible with operand shape %s",
1509 ShapeUtil::HumanStringWithLayout(shape_with_layout),
1510 ShapeUtil::HumanStringWithLayout(*operand_shape));
1511 }
1512 *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1513
1514 instr.set_outfeed_config(outfeed_config);
1515
1516 return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
1517 {operand, token});
1518 });
1519 }
1520
1521 XlaOp XlaBuilder::CreateToken() {
1522 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1523 HloInstructionProto instr;
1524 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1525 return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
1526 });
1527 }
1528
1529 XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
1530 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1531 if (tokens.empty()) {
1532 return InvalidArgument("AfterAll requires at least one operand");
1533 }
1534 for (int i = 0; i < tokens.size(); ++i) {
1535 XlaOp operand = tokens[i];
1536 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1537 if (!operand_shape->IsToken()) {
1538 return InvalidArgument(
1539 "All operands to AfterAll must be tokens; operand %d has shape %s",
1540 i, ShapeUtil::HumanString(*operand_shape));
1541 }
1542 }
1543 HloInstructionProto instr;
1544 *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1545 return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
1546 });
1547 }
1548
1549 XlaOp XlaBuilder::CustomCall(
1550 const string& call_target_name, absl::Span<const XlaOp> operands,
1551 const Shape& shape, const string& opaque,
1552 absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
1553 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1554 HloInstructionProto instr;
1555 if (absl::StartsWith(call_target_name, "$")) {
1556 return InvalidArgument(
1557 "Invalid custom_call_target \"%s\": Call targets that start with '$' "
1558 "are reserved for internal use.",
1559 call_target_name);
1560 }
1561 *instr.mutable_shape() = shape.ToProto();
1562 instr.set_custom_call_target(call_target_name);
1563 instr.set_backend_config(opaque);
1564 if (operand_shapes_with_layout.has_value()) {
1565 if (!LayoutUtil::HasLayout(shape)) {
1566 return InvalidArgument(
1567 "Result shape must have layout for custom call with constrained "
1568 "layout.");
1569 }
1570 if (operands.size() != operand_shapes_with_layout->size()) {
1571 return InvalidArgument(
1572 "Must specify a shape with layout for each operand for custom call "
1573 "with constrained layout; given %d shapes, expected %d",
1574 operand_shapes_with_layout->size(), operands.size());
1575 }
1576 instr.set_constrain_layout(true);
1577 int64 operand_num = 0;
1578 for (const Shape& operand_shape : *operand_shapes_with_layout) {
1579 if (!LayoutUtil::HasLayout(operand_shape)) {
1580 return InvalidArgument(
1581 "No layout specified for operand %d for custom call with "
1582 "constrained layout.",
1583 operand_num);
1584 }
1585 *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
1586 ++operand_num;
1587 }
1588 }
1589 return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
1590 });
1591 }
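
// Hedged usage sketch (commented out): a CustomCall with an opaque
// backend-config string. The target name "my_custom_kernel" is hypothetical
// and would have to be registered with the backend separately; targets
// starting with '$' are rejected above as reserved.
//
//   XlaBuilder b("custom_call_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {16}), "x");
//   XlaOp y = CustomCall(&b, "my_custom_kernel", /*operands=*/{x},
//                        /*shape=*/ShapeUtil::MakeShape(F32, {16}),
//                        /*opaque=*/"attr=1");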
1592
1593 XlaOp XlaBuilder::Transpose(XlaOp operand,
1594 absl::Span<const int64> permutation) {
1595 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1596 HloInstructionProto instr;
1597 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1598 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
1599 *operand_shape, permutation));
1600 *instr.mutable_shape() = shape.ToProto();
1601 for (int64 dim : permutation) {
1602 instr.add_dimensions(dim);
1603 }
1604 return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
1605 });
1606 }
1607
1608 XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
1609 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1610 HloInstructionProto instr;
1611 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1612 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
1613 *operand_shape, dimensions));
1614 *instr.mutable_shape() = shape.ToProto();
1615 for (int64 dim : dimensions) {
1616 instr.add_dimensions(dim);
1617 }
1618 return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
1619 });
1620 }
1621
1622 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
1623 const XlaComputation& comparator, int64 dimension,
1624 bool is_stable) {
1625 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1626 HloInstructionProto instr;
1627 instr.set_is_stable(is_stable);
1628 std::vector<const Shape*> operand_shape_ptrs;
1629 TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
1630 GetOperandShapes(operands));
1631 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1632 [](const Shape& shape) { return &shape; });
1633 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
1634 HloOpcode::kSort, operand_shape_ptrs));
1635 *instr.mutable_shape() = shape.ToProto();
1636 if (dimension == -1) {
1637 TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
1638 dimension = keys_shape->rank() - 1;
1639 }
1640 instr.add_dimensions(dimension);
1641 AddCalledComputation(comparator, &instr);
1642 return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
1643 });
1644 }
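
// Hedged usage sketch (commented out): sorting a key array and a co-sorted
// value array with a scalar comparator built in a sub-builder. The comparator
// takes two scalar parameters per sorted operand, and dimension -1 above means
// "last dimension". All names are hypothetical.
//
//   XlaBuilder cb("lt_comparator");
//   XlaOp k0 = Parameter(&cb, 0, ShapeUtil::MakeShape(F32, {}), "k0");
//   XlaOp k1 = Parameter(&cb, 1, ShapeUtil::MakeShape(F32, {}), "k1");
//   Parameter(&cb, 2, ShapeUtil::MakeShape(S32, {}), "v0");
//   Parameter(&cb, 3, ShapeUtil::MakeShape(S32, {}), "v1");
//   Lt(k0, k1);  // Root: compare keys only.
//   XlaComputation comparator = cb.Build().ValueOrDie();
//
//   XlaBuilder b("sort_example");
//   XlaOp keys = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "keys");
//   XlaOp vals = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {8}), "vals");
//   Sort({keys, vals}, comparator, /*dimension=*/-1, /*is_stable=*/true);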
1645
1646 XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
1647 PrimitiveType new_element_type) {
1648 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1649 HloInstructionProto instr;
1650 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1651 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
1652 *operand_shape, new_element_type));
1653 *instr.mutable_shape() = shape.ToProto();
1654 return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
1655 });
1656 }
1657
1658 XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
1659 PrimitiveType new_element_type) {
1660 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1661 HloInstructionProto instr;
1662 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1663 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
1664 *operand_shape, new_element_type));
1665 *instr.mutable_shape() = shape.ToProto();
1666 return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
1667 {operand});
1668 });
1669 }
1670
1671 XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
1672 return TernaryOp(HloOpcode::kClamp, min, operand, max);
1673 }
1674
1675 XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
1676 const XlaComputation& computation,
1677 absl::Span<const int64> dimensions,
1678 absl::Span<const XlaOp> static_operands) {
1679 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1680 if (!static_operands.empty()) {
1681 return Unimplemented("static_operands is not supported in Map");
1682 }
1683
1684 HloInstructionProto instr;
1685 std::vector<const Shape*> operand_shape_ptrs;
1686 TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
1687 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1688 [](const Shape& shape) { return &shape; });
1689 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
1690 computation.GetProgramShape());
1691 TF_ASSIGN_OR_RETURN(
1692 Shape shape, ShapeInference::InferMapShape(
1693 operand_shape_ptrs, called_program_shape, dimensions));
1694 *instr.mutable_shape() = shape.ToProto();
1695
1696 Shape output_shape(instr.shape());
1697 const int64 output_rank = output_shape.rank();
1698 AddCalledComputation(computation, &instr);
1699 std::vector<XlaOp> new_operands(operands.begin(), operands.end());
1700 for (XlaOp& new_operand : new_operands) {
1701 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(new_operand));
1702 const int64 rank = shape->rank();
1703 if (rank != output_rank) {
1704 TF_ASSIGN_OR_RETURN(new_operand,
1705 InDimBroadcast(output_shape, new_operand, {}));
1706 TF_ASSIGN_OR_RETURN(shape, GetShapePtr(new_operand));
1707 }
1708 if (!ShapeUtil::SameDimensions(output_shape, *shape)) {
1709 TF_ASSIGN_OR_RETURN(new_operand,
1710 AddBroadcastSequence(output_shape, new_operand));
1711 }
1712 }
1713
1714 return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
1715 });
1716 }
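
// Hedged usage sketch (commented out): Map applies a scalar computation
// elementwise; operands whose rank or dimensions differ from the output are
// broadcast above before the kMap instruction is emitted. Names are
// hypothetical.
//
//   XlaBuilder mb("scale_and_add");
//   XlaOp a = Parameter(&mb, 0, ShapeUtil::MakeShape(F32, {}), "a");
//   XlaOp c = Parameter(&mb, 1, ShapeUtil::MakeShape(F32, {}), "c");
//   Add(Mul(a, ConstantR0<float>(&mb, 2.0f)), c);
//   XlaComputation scale_and_add = mb.Build().ValueOrDie();
//
//   XlaBuilder b("map_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4}), "y");
//   Map(&b, {x, y}, scale_and_add, /*dimensions=*/{0});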
1717
1718 XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
1719 absl::Span<const XlaOp> parameters,
1720 const Shape& shape) {
1721 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1722 HloInstructionProto instr;
1723
1724 // Check the number of parameters per RNG distribution.
1725 switch (distribution) {
1726 case RandomDistribution::RNG_NORMAL:
1727 case RandomDistribution::RNG_UNIFORM:
1728 if (parameters.size() != 2) {
1729 return InvalidArgument(
1730 "RNG distribution (%s) expects 2 parameters, but got %ld",
1731 RandomDistribution_Name(distribution), parameters.size());
1732 }
1733 break;
1734 default:
1735 LOG(FATAL) << "unhandled distribution " << distribution;
1736 }
1737
1738 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
1739 *instr.mutable_shape() = shape.ToProto();
1740
1741 instr.set_distribution(distribution);
1742
1743 return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
1744 });
1745 }
1746
1747 XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
1748 return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
1749 }
1750
1751 XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
1752 return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
1753 }
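
// Hedged usage sketch (commented out): both wrappers take scalar distribution
// parameters plus an output shape, and RngOp above checks that exactly two
// parameters are supplied. The shape below is hypothetical.
//
//   XlaBuilder b("rng_example");
//   Shape s = ShapeUtil::MakeShape(F32, {2, 2});
//   XlaOp n = RngNormal(ConstantR0<float>(&b, 0.0f),   // mu
//                       ConstantR0<float>(&b, 1.0f),   // sigma
//                       s);
//   XlaOp u = RngUniform(ConstantR0<float>(&b, 0.0f),  // a (inclusive)
//                        ConstantR0<float>(&b, 1.0f),  // b (exclusive)
//                        s);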
1754
1755 XlaOp XlaBuilder::While(const XlaComputation& condition,
1756 const XlaComputation& body, XlaOp init) {
1757 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1758 HloInstructionProto instr;
1759
1760 // Infer shape.
1761 TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
1762 TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
1763 condition.GetProgramShape());
1764 TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
1765 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
1766 condition_program_shape,
1767 body_program_shape, *init_shape));
1768 *instr.mutable_shape() = shape.ToProto();
1769 // Body comes before condition computation in the vector.
1770 AddCalledComputation(body, &instr);
1771 AddCalledComputation(condition, &instr);
1772 return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
1773 });
1774 }
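
// Hedged usage sketch (commented out): a counting loop. The condition and body
// computations each take the loop state (here a scalar S32) as their single
// parameter; note that the body is attached before the condition above. Names
// are hypothetical.
//
//   XlaBuilder condb("cond");
//   Lt(Parameter(&condb, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//      ConstantR0<int32>(&condb, 10));
//   XlaComputation cond = condb.Build().ValueOrDie();
//
//   XlaBuilder bodyb("body");
//   Add(Parameter(&bodyb, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//       ConstantR0<int32>(&bodyb, 1));
//   XlaComputation body = bodyb.Build().ValueOrDie();
//
//   XlaBuilder b("while_example");
//   While(cond, body, /*init=*/ConstantR0<int32>(&b, 0));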
1775
1776 XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
1777 const GatherDimensionNumbers& dimension_numbers,
1778 absl::Span<const int64> slice_sizes,
1779 bool indices_are_sorted) {
1780 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1781 HloInstructionProto instr;
1782 instr.set_indices_are_sorted(indices_are_sorted);
1783
1784 TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
1785 TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape,
1786 GetShapePtr(start_indices));
1787 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
1788 *input_shape, *start_indices_shape,
1789 dimension_numbers, slice_sizes));
1790 *instr.mutable_shape() = shape.ToProto();
1791
1792 *instr.mutable_gather_dimension_numbers() = dimension_numbers;
1793 for (int64 bound : slice_sizes) {
1794 instr.add_gather_slice_sizes(bound);
1795 }
1796
1797 return AddInstruction(std::move(instr), HloOpcode::kGather,
1798 {input, start_indices});
1799 });
1800 }
1801
1802 XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1803 const XlaComputation& update_computation,
1804 const ScatterDimensionNumbers& dimension_numbers,
1805 bool indices_are_sorted, bool unique_indices) {
1806 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1807 HloInstructionProto instr;
1808 instr.set_indices_are_sorted(indices_are_sorted);
1809
1810 instr.set_unique_indices(unique_indices);
1811
1812 TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
1813 TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
1814 GetShapePtr(scatter_indices));
1815 TF_ASSIGN_OR_RETURN(const Shape* updates_shape, GetShapePtr(updates));
1816 TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
1817 update_computation.GetProgramShape());
1818 TF_ASSIGN_OR_RETURN(
1819 Shape shape, ShapeInference::InferScatterShape(
1820 *input_shape, *scatter_indices_shape, *updates_shape,
1821 to_apply_shape, dimension_numbers));
1822 *instr.mutable_shape() = shape.ToProto();
1823
1824 *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
1825
1826 AddCalledComputation(update_computation, &instr);
1827 return AddInstruction(std::move(instr), HloOpcode::kScatter,
1828 {input, scatter_indices, updates});
1829 });
1830 }
1831
1832 XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
1833 const XlaComputation& true_computation,
1834 XlaOp false_operand,
1835 const XlaComputation& false_computation) {
1836 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1837 TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(predicate));
1838
1839 if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) {
1840 return InvalidArgument(
1841 "Argument to predicated-Conditional is not a scalar of PRED type "
1842 "(%s).",
1843 ShapeUtil::HumanString(*shape));
1844 }
1845     // The index of true_computation must be 0 and that of false_computation
1846     // must be 1.
1847 return ConditionalImpl(predicate, {&true_computation, &false_computation},
1848 {true_operand, false_operand});
1849 });
1850 }
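
// Hedged usage sketch (commented out): the predicated form above simply maps
// true_computation to branch index 0 and false_computation to index 1 and
// defers to ConditionalImpl. Names are hypothetical.
//
//   XlaBuilder tb("true_branch");
//   Neg(Parameter(&tb, 0, ShapeUtil::MakeShape(F32, {}), "x"));
//   XlaComputation true_comp = tb.Build().ValueOrDie();
//
//   XlaBuilder fb("false_branch");
//   Abs(Parameter(&fb, 0, ShapeUtil::MakeShape(F32, {}), "x"));
//   XlaComputation false_comp = fb.Build().ValueOrDie();
//
//   XlaBuilder b("cond_example");
//   XlaOp pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p");
//   XlaOp x = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "x");
//   Conditional(pred, /*true_operand=*/x, true_comp,
//               /*false_operand=*/x, false_comp);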
1851
1852 XlaOp XlaBuilder::Conditional(
1853 XlaOp branch_index,
1854 absl::Span<const XlaComputation* const> branch_computations,
1855 absl::Span<const XlaOp> branch_operands) {
1856 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1857 TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index));
1858
1859 if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) {
1860 return InvalidArgument(
1861 "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
1862 ShapeUtil::HumanString(*shape));
1863 }
1864 return ConditionalImpl(branch_index, branch_computations, branch_operands);
1865 });
1866 }
1867
1868 XlaOp XlaBuilder::ConditionalImpl(
1869 XlaOp branch_index,
1870 absl::Span<const XlaComputation* const> branch_computations,
1871 absl::Span<const XlaOp> branch_operands) {
1872 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1873 HloInstructionProto instr;
1874
1875 TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
1876 GetShapePtr(branch_index));
1877 std::vector<Shape> branch_operand_shapes(branch_operands.size());
1878 std::vector<ProgramShape> branch_computation_shapes(
1879 branch_computations.size());
1880 for (int j = 0; j < branch_operands.size(); ++j) {
1881 TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
1882 GetShape(branch_operands[j]));
1883 TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
1884 branch_computations[j]->GetProgramShape());
1885 }
1886 TF_ASSIGN_OR_RETURN(const Shape shape,
1887 ShapeInference::InferConditionalShape(
1888 *branch_index_shape, branch_computation_shapes,
1889 branch_operand_shapes));
1890 *instr.mutable_shape() = shape.ToProto();
1891
1892 for (const XlaComputation* branch_computation : branch_computations) {
1893 AddCalledComputation(*branch_computation, &instr);
1894 }
1895
1896 std::vector<XlaOp> operands(1, branch_index);
1897 for (const XlaOp branch_operand : branch_operands) {
1898 operands.emplace_back(branch_operand);
1899 }
1900 return AddInstruction(std::move(instr), HloOpcode::kConditional,
1901 absl::MakeSpan(operands));
1902 });
1903 }
1904
1905 XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
1906 const XlaComputation& computation,
1907 absl::Span<const int64> dimensions_to_reduce) {
1908 return Reduce(absl::Span<const XlaOp>({operand}),
1909 absl::Span<const XlaOp>({init_value}), computation,
1910 dimensions_to_reduce);
1911 }
1912
1913 XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
1914 absl::Span<const XlaOp> init_values,
1915 const XlaComputation& computation,
1916 absl::Span<const int64> dimensions_to_reduce) {
1917 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1918 HloInstructionProto instr;
1919
1920 TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
1921 computation.GetProgramShape());
1922
1923 std::vector<XlaOp> all_operands;
1924 all_operands.insert(all_operands.end(), operands.begin(), operands.end());
1925 all_operands.insert(all_operands.end(), init_values.begin(),
1926 init_values.end());
1927
1928 std::vector<const Shape*> operand_shape_ptrs;
1929 TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
1930 GetOperandShapes(all_operands));
1931 absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1932 [](const Shape& shape) { return &shape; });
1933
1934 TF_ASSIGN_OR_RETURN(
1935 Shape shape,
1936 ShapeInference::InferReduceShape(
1937 operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
1938 *instr.mutable_shape() = shape.ToProto();
1939
1940 for (int64 dim : dimensions_to_reduce) {
1941 instr.add_dimensions(dim);
1942 }
1943
1944 AddCalledComputation(computation, &instr);
1945
1946 return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
1947 });
1948 }
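
// Hedged usage sketch (commented out): a row sum over dimension 1 of a 2x3
// array using a scalar add computation. Each reduced operand is paired with a
// scalar init value; all names are hypothetical.
//
//   XlaBuilder ab("add");
//   Add(Parameter(&ab, 0, ShapeUtil::MakeShape(F32, {}), "x"),
//       Parameter(&ab, 1, ShapeUtil::MakeShape(F32, {}), "y"));
//   XlaComputation add = ab.Build().ValueOrDie();
//
//   XlaBuilder b("reduce_example");
//   XlaOp m = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "m");
//   Reduce(m, ConstantR0<float>(&b, 0.0f), add,
//          /*dimensions_to_reduce=*/{1});  // Result shape: f32[2].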
1949
1950 XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
1951 const XlaComputation& computation) {
1952 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1953 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1954 std::vector<int64> all_dimnos(operand_shape->rank());
1955 std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
1956 return Reduce(operand, init_value, computation, all_dimnos);
1957 });
1958 }
1959
1960 XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
1961 const XlaComputation& computation,
1962 absl::Span<const int64> window_dimensions,
1963 absl::Span<const int64> window_strides,
1964 Padding padding) {
1965 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1966 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1967 TF_RETURN_IF_ERROR(
1968 ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
1969 window_dimensions, window_strides));
1970
1971 std::vector<std::pair<int64, int64>> padding_values =
1972 MakePadding(AsInt64Slice(operand_shape->dimensions()),
1973 window_dimensions, window_strides, padding);
1974 return ReduceWindowWithGeneralPadding(
1975 operand, init_value, computation, window_dimensions, window_strides,
1976 /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
1977 });
1978 }
1979
1980 XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
1981 XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1982 absl::Span<const int64> window_dimensions,
1983 absl::Span<const int64> window_strides,
1984 absl::Span<const int64> base_dilations,
1985 absl::Span<const int64> window_dilations,
1986 absl::Span<const std::pair<int64, int64>> padding) {
1987 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1988 HloInstructionProto instr;
1989
1990 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
1991 TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
1992 TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
1993 computation.GetProgramShape());
1994 TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
1995 ShapeInference::InferWindowFromDimensions(
1996 window_dimensions, window_strides, padding,
1997 /*lhs_dilation=*/base_dilations,
1998 /*rhs_dilation=*/window_dilations));
1999 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
2000 *operand_shape, *init_shape,
2001 instr.window(), to_apply_shape));
2002 *instr.mutable_shape() = shape.ToProto();
2003
2004 AddCalledComputation(computation, &instr);
2005 return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
2006 {operand, init_value});
2007 });
2008 }
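
// Hedged usage sketch (commented out): a 2x2, stride-2 max pool expressed as a
// ReduceWindow over a scalar max computation; ReduceWindow above turns the
// Padding enum into explicit padding pairs and passes empty dilations. The
// init value should be the identity of the reduction (here -infinity for max).
// Names and sizes are hypothetical.
//
//   XlaBuilder mb("max");
//   Max(Parameter(&mb, 0, ShapeUtil::MakeShape(F32, {}), "x"),
//       Parameter(&mb, 1, ShapeUtil::MakeShape(F32, {}), "y"));
//   XlaComputation max = mb.Build().ValueOrDie();
//
//   XlaBuilder b("pool_example");
//   XlaOp img = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 8}), "img");
//   XlaOp neg_inf =
//       ConstantR0<float>(&b, -std::numeric_limits<float>::infinity());
//   ReduceWindow(img, neg_inf, max,
//                /*window_dimensions=*/{2, 2}, /*window_strides=*/{2, 2},
//                Padding::kValid);  // Result shape: f32[4, 4].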
2009
2010 XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
2011 float epsilon, int64 feature_index) {
2012 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2013 HloInstructionProto instr;
2014
2015 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2016 TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2017 TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2018 TF_ASSIGN_OR_RETURN(
2019 Shape shape,
2020 ShapeInference::InferBatchNormTrainingShape(
2021 *operand_shape, *scale_shape, *offset_shape, feature_index));
2022 *instr.mutable_shape() = shape.ToProto();
2023
2024 instr.set_epsilon(epsilon);
2025 instr.set_feature_index(feature_index);
2026
2027 return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
2028 {operand, scale, offset});
2029 });
2030 }
2031
2032 XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
2033 XlaOp mean, XlaOp variance, float epsilon,
2034 int64 feature_index) {
2035 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2036 HloInstructionProto instr;
2037
2038 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2039 TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2040 TF_ASSIGN_OR_RETURN(const Shape* offset_shape, GetShapePtr(offset));
2041 TF_ASSIGN_OR_RETURN(const Shape* mean_shape, GetShapePtr(mean));
2042 TF_ASSIGN_OR_RETURN(const Shape* variance_shape, GetShapePtr(variance));
2043 TF_ASSIGN_OR_RETURN(Shape shape,
2044 ShapeInference::InferBatchNormInferenceShape(
2045 *operand_shape, *scale_shape, *offset_shape,
2046 *mean_shape, *variance_shape, feature_index));
2047 *instr.mutable_shape() = shape.ToProto();
2048
2049 instr.set_epsilon(epsilon);
2050 instr.set_feature_index(feature_index);
2051
2052 return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
2053 {operand, scale, offset, mean, variance});
2054 });
2055 }
2056
2057 XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
2058 XlaOp batch_var, XlaOp grad_output,
2059 float epsilon, int64 feature_index) {
2060 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2061 HloInstructionProto instr;
2062
2063 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2064 TF_ASSIGN_OR_RETURN(const Shape* scale_shape, GetShapePtr(scale));
2065 TF_ASSIGN_OR_RETURN(const Shape* batch_mean_shape, GetShapePtr(batch_mean));
2066 TF_ASSIGN_OR_RETURN(const Shape* batch_var_shape, GetShapePtr(batch_var));
2067 TF_ASSIGN_OR_RETURN(const Shape* grad_output_shape,
2068 GetShapePtr(grad_output));
2069 TF_ASSIGN_OR_RETURN(
2070 Shape shape, ShapeInference::InferBatchNormGradShape(
2071 *operand_shape, *scale_shape, *batch_mean_shape,
2072 *batch_var_shape, *grad_output_shape, feature_index));
2073 *instr.mutable_shape() = shape.ToProto();
2074
2075 instr.set_epsilon(epsilon);
2076 instr.set_feature_index(feature_index);
2077
2078 return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
2079 {operand, scale, batch_mean, batch_var, grad_output});
2080 });
2081 }
2082
2083 XlaOp XlaBuilder::CrossReplicaSum(
2084 XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
2085 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2086 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
2087 const Shape* element_shape;
2088 if (shape->IsTuple()) {
2089 if (shape->tuple_shapes_size() == 0) {
2090 return Unimplemented(
2091 "0 element tuple CrossReplicaSum is not supported");
2092 }
2093 element_shape = &shape->tuple_shapes(0);
2094 } else {
2095 element_shape = shape;
2096 }
2097 const Shape scalar_shape =
2098 ShapeUtil::MakeShape(element_shape->element_type(), {});
2099 auto b = CreateSubBuilder("sum");
2100 auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
2101 auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
2102 if (scalar_shape.element_type() == PRED) {
2103 Or(x, y);
2104 } else {
2105 Add(x, y);
2106 }
2107 TF_ASSIGN_OR_RETURN(auto computation, b->Build());
2108 return AllReduce(operand, computation, replica_groups,
2109 /*channel_id=*/absl::nullopt);
2110 });
2111 }
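
// Hedged usage sketch (commented out): CrossReplicaSum is a convenience
// wrapper that synthesizes the scalar add (or logical-or for PRED) reduction
// above and forwards to AllReduce. An empty replica_groups span means a single
// group containing all replicas. Names are hypothetical.
//
//   XlaBuilder b("crs_example");
//   XlaOp grad = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {128}), "grad");
//   CrossReplicaSum(grad, /*replica_groups=*/{});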
2112
2113 XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
2114 absl::Span<const ReplicaGroup> replica_groups,
2115 const absl::optional<ChannelHandle>& channel_id,
2116 const absl::optional<Shape>& shape_with_layout) {
2117 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2118 HloInstructionProto instr;
2119 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2120 std::vector<const Shape*> operand_shapes;
2121 std::vector<XlaOp> operands;
2122 if (operand_shape->IsTuple()) {
2123 if (operand_shape->tuple_shapes_size() == 0) {
2124 return Unimplemented("0 element tuple AllReduce is not supported");
2125 }
2126 for (int64 i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
2127 if (operand_shape->tuple_shapes(i).element_type() !=
2128 operand_shape->tuple_shapes(0).element_type()) {
2129 return Unimplemented(
2130 "All the shapes of a tuple input of AllReduce must have the same "
2131 "element type");
2132 }
2133 operand_shapes.push_back(&operand_shape->tuple_shapes(i));
2134 operands.push_back(GetTupleElement(operand, i));
2135 }
2136 } else {
2137 operand_shapes.push_back(operand_shape);
2138 operands.push_back(operand);
2139 }
2140
2141 TF_ASSIGN_OR_RETURN(Shape inferred_shape,
2142 ShapeInference::InferAllReduceShape(operand_shapes));
2143 if (shape_with_layout) {
2144 if (!LayoutUtil::HasLayout(*shape_with_layout)) {
2145 return InvalidArgument("shape_with_layout must have the layout set: %s",
2146 shape_with_layout->ToString());
2147 }
2148 if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
2149 return InvalidArgument(
2150 "Provided shape_with_layout must be compatible with the "
2151 "operand shape: %s vs %s",
2152 shape_with_layout->ToString(), operand_shape->ToString());
2153 }
2154 instr.set_constrain_layout(true);
2155 if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
2156 // For a single-element tuple, take the tuple element shape.
2157 TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
2158 *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
2159 } else {
2160 *instr.mutable_shape() = shape_with_layout->ToProto();
2161 }
2162 } else {
2163 *instr.mutable_shape() = inferred_shape.ToProto();
2164 }
2165
2166 for (const ReplicaGroup& group : replica_groups) {
2167 *instr.add_replica_groups() = group;
2168 }
2169
2170 if (channel_id.has_value()) {
2171 instr.set_channel_id(channel_id->handle());
2172 }
2173
2174 AddCalledComputation(computation, &instr);
2175
2176 TF_ASSIGN_OR_RETURN(
2177 auto all_reduce,
2178 AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
2179 if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
2180 // For a single-element tuple, wrap the result into a tuple.
2181 TF_RET_CHECK(operand_shapes.size() == 1);
2182 TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
2183 return Tuple({all_reduce});
2184 }
2185 return all_reduce;
2186 });
2187 }
2188
2189 XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
2190 int64 concat_dimension, int64 split_count,
2191 const std::vector<ReplicaGroup>& replica_groups) {
2192 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2193 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2194
2195     // The HloInstruction for AllToAll currently only handles the data
2196     // communication: it accepts N already-split parts and scatters them to N
2197     // cores, and each core gathers the N received parts into a tuple as the
2198     // output. So here we explicitly split the operand before the HLO AllToAll
2199     // and concatenate the tuple elements afterwards.
2200 //
2201 // First, run shape inference to make sure the shapes are valid.
2202 TF_RETURN_IF_ERROR(
2203 ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
2204 concat_dimension, split_count)
2205 .status());
2206
2207 // Split into N parts.
2208 std::vector<XlaOp> slices;
2209 slices.reserve(split_count);
2210 const int64 block_size =
2211 operand_shape->dimensions(split_dimension) / split_count;
2212 for (int i = 0; i < split_count; i++) {
2213 slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
2214 /*limit_index=*/(i + 1) * block_size,
2215 /*stride=*/1, /*dimno=*/split_dimension));
2216 }
2217
2218 // Handle data communication.
2219 HloInstructionProto instr;
2220 TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
2221 std::vector<const Shape*> slice_shape_ptrs;
2222 absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
2223 [](const Shape& shape) { return &shape; });
2224 TF_ASSIGN_OR_RETURN(
2225 Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
2226 *instr.mutable_shape() = shape.ToProto();
2227 for (const ReplicaGroup& group : replica_groups) {
2228 *instr.add_replica_groups() = group;
2229 }
2230 TF_ASSIGN_OR_RETURN(
2231 XlaOp alltoall,
2232 AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));
2233
2234 // Concat the N received parts.
2235 std::vector<XlaOp> received;
2236 received.reserve(split_count);
2237 for (int i = 0; i < split_count; i++) {
2238 received.push_back(this->GetTupleElement(alltoall, i));
2239 }
2240 return this->ConcatInDim(received, concat_dimension);
2241 });
2242 }
2243
2244 XlaOp XlaBuilder::CollectivePermute(
2245 XlaOp operand,
2246 const std::vector<std::pair<int64, int64>>& source_target_pairs) {
2247 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2248 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2249 HloInstructionProto instr;
2250 TF_ASSIGN_OR_RETURN(
2251 Shape shape,
2252 ShapeInference::InferCollectivePermuteShape(*operand_shape));
2253 *instr.mutable_shape() = shape.ToProto();
2254
2255 for (const auto& pair : source_target_pairs) {
2256 auto* proto_pair = instr.add_source_target_pairs();
2257 proto_pair->set_source(pair.first);
2258 proto_pair->set_target(pair.second);
2259 }
2260
2261 return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
2262 {operand});
2263 });
2264 }
2265
2266 XlaOp XlaBuilder::ReplicaId() {
2267 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2268 HloInstructionProto instr;
2269 *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
2270 return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
2271 });
2272 }
2273
2274 XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
2275 absl::Span<const int64> window_dimensions,
2276 absl::Span<const int64> window_strides,
2277 Padding padding, XlaOp source,
2278 XlaOp init_value,
2279 const XlaComputation& scatter) {
2280 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2281 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2282 return SelectAndScatterWithGeneralPadding(
2283 operand, select, window_dimensions, window_strides,
2284 MakePadding(AsInt64Slice(operand_shape->dimensions()),
2285 window_dimensions, window_strides, padding),
2286 source, init_value, scatter);
2287 });
2288 }
2289
2290 XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
2291 XlaOp operand, const XlaComputation& select,
2292 absl::Span<const int64> window_dimensions,
2293 absl::Span<const int64> window_strides,
2294 absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
2295 XlaOp init_value, const XlaComputation& scatter) {
2296 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2297 HloInstructionProto instr;
2298
2299 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2300 TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(source));
2301 TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
2302 TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
2303 select.GetProgramShape());
2304 TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
2305 scatter.GetProgramShape());
2306 TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
2307 ShapeInference::InferWindowFromDimensions(
2308 window_dimensions, window_strides, padding,
2309 /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
2310 TF_ASSIGN_OR_RETURN(Shape shape,
2311 ShapeInference::InferSelectAndScatterShape(
2312 *operand_shape, select_shape, instr.window(),
2313 *source_shape, *init_shape, scatter_shape));
2314 *instr.mutable_shape() = shape.ToProto();
2315
2316 AddCalledComputation(select, &instr);
2317 AddCalledComputation(scatter, &instr);
2318
2319 return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
2320 {operand, source, init_value});
2321 });
2322 }
2323
2324 XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
2325 const int mantissa_bits) {
2326 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2327 HloInstructionProto instr;
2328 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2329 TF_ASSIGN_OR_RETURN(Shape shape,
2330 ShapeInference::InferReducePrecisionShape(
2331 *operand_shape, exponent_bits, mantissa_bits));
2332 *instr.mutable_shape() = shape.ToProto();
2333 instr.set_exponent_bits(exponent_bits);
2334 instr.set_mantissa_bits(mantissa_bits);
2335 return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
2336 {operand});
2337 });
2338 }
2339
2340 void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
2341 ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2342 // Send HLO takes two operands: a data operand and a token. Generate the
2343 // token to pass into the send.
2344 // TODO(b/80000000): Remove this when clients have been updated to handle
2345 // tokens.
2346 HloInstructionProto token_instr;
2347 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2348 TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
2349 HloOpcode::kAfterAll, {}));
2350
2351 return SendWithToken(operand, token, handle);
2352 });
2353 }
2354
2355 XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
2356 const ChannelHandle& handle) {
2357 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2358 if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
2359 return InvalidArgument("Send must use a device-to-device channel");
2360 }
2361
2362 // Send instruction produces a tuple of {aliased operand, U32 context,
2363 // token}.
2364 HloInstructionProto send_instr;
2365 TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
2366 *send_instr.mutable_shape() =
2367 ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}),
2368 ShapeUtil::MakeTokenShape()})
2369 .ToProto();
2370 send_instr.set_channel_id(handle.handle());
2371 TF_ASSIGN_OR_RETURN(XlaOp send,
2372 AddInstruction(std::move(send_instr), HloOpcode::kSend,
2373 {operand, token}));
2374
2375 HloInstructionProto send_done_instr;
2376 *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2377 send_done_instr.set_channel_id(handle.handle());
2378 return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
2379 {send});
2380 });
2381 }
2382
2383 XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
2384 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2385 // Recv HLO takes a single token operand. Generate the token to pass into
2386 // the Recv and RecvDone instructions.
2387 // TODO(b/80000000): Remove this when clients have been updated to handle
2388 // tokens.
2389 HloInstructionProto token_instr;
2390 *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2391 TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
2392 HloOpcode::kAfterAll, {}));
2393
2394 XlaOp recv = RecvWithToken(token, shape, handle);
2395
2396     // The RecvDone instruction produces a tuple of the data and a token.
2397     // Return the XLA op containing the data.
2398 // TODO(b/80000000): Remove this when clients have been updated to handle
2399 // tokens.
2400 HloInstructionProto recv_data;
2401 *recv_data.mutable_shape() = shape.ToProto();
2402 recv_data.set_tuple_index(0);
2403 return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
2404 {recv});
2405 });
2406 }
2407
2408 XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
2409 const ChannelHandle& handle) {
2410 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2411 if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
2412 return InvalidArgument("Recv must use a device-to-device channel");
2413 }
2414
2415 // Recv instruction produces a tuple of {receive buffer, U32 context,
2416 // token}.
2417 HloInstructionProto recv_instr;
2418 *recv_instr.mutable_shape() =
2419 ShapeUtil::MakeTupleShape(
2420 {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
2421 .ToProto();
2422 recv_instr.set_channel_id(handle.handle());
2423 TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
2424 HloOpcode::kRecv, {token}));
2425
2426 HloInstructionProto recv_done_instr;
2427 *recv_done_instr.mutable_shape() =
2428 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
2429 .ToProto();
2430 recv_done_instr.set_channel_id(handle.handle());
2431 return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
2432 {recv});
2433 });
2434 }
2435
2436 XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
2437 const Shape& shape_with_layout,
2438 const ChannelHandle& handle) {
2439 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2440 if (!LayoutUtil::HasLayout(shape_with_layout)) {
2441 return InvalidArgument("Shape passed to SendToHost must have a layout");
2442 }
2443 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2444 if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
2445 return InvalidArgument(
2446 "SendToHost shape %s must be compatible with operand shape %s",
2447 ShapeUtil::HumanStringWithLayout(shape_with_layout),
2448 ShapeUtil::HumanStringWithLayout(*operand_shape));
2449 }
2450 // TODO(b/111544877): Support tuple shapes.
2451 if (!operand_shape->IsArray()) {
2452 return InvalidArgument("SendToHost only supports array shapes, shape: %s",
2453 ShapeUtil::HumanString(*operand_shape));
2454 }
2455
2456 if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
2457 return InvalidArgument("SendToHost must use a device-to-host channel");
2458 }
2459
2460 // Send instruction produces a tuple of {aliased operand, U32 context,
2461 // token}.
2462 HloInstructionProto send_instr;
2463 *send_instr.mutable_shape() =
2464 ShapeUtil::MakeTupleShape({shape_with_layout,
2465 ShapeUtil::MakeShape(U32, {}),
2466 ShapeUtil::MakeTokenShape()})
2467 .ToProto();
2468 send_instr.set_channel_id(handle.handle());
2469 send_instr.set_is_host_transfer(true);
2470 TF_ASSIGN_OR_RETURN(XlaOp send,
2471 AddInstruction(std::move(send_instr), HloOpcode::kSend,
2472 {operand, token}));
2473
2474 HloInstructionProto send_done_instr;
2475 *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2476 send_done_instr.set_channel_id(handle.handle());
2477 send_done_instr.set_is_host_transfer(true);
2478 return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
2479 {send});
2480 });
2481 }
2482
2483 XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
2484 const ChannelHandle& handle) {
2485 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2486 if (!LayoutUtil::HasLayout(shape)) {
2487 return InvalidArgument("Shape passed to RecvFromHost must have a layout");
2488 }
2489
2490 // TODO(b/111544877): Support tuple shapes.
2491 if (!shape.IsArray()) {
2492 return InvalidArgument(
2493 "RecvFromHost only supports array shapes, shape: %s",
2494 ShapeUtil::HumanString(shape));
2495 }
2496
2497 if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
2498 return InvalidArgument("RecvFromHost must use a host-to-device channel");
2499 }
2500
2501 // Recv instruction produces a tuple of {receive buffer, U32 context,
2502 // token}.
2503 HloInstructionProto recv_instr;
2504 *recv_instr.mutable_shape() =
2505 ShapeUtil::MakeTupleShape(
2506 {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
2507 .ToProto();
2508 recv_instr.set_channel_id(handle.handle());
2509 recv_instr.set_is_host_transfer(true);
2510 TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
2511 HloOpcode::kRecv, {token}));
2512
2513 HloInstructionProto recv_done_instr;
2514 *recv_done_instr.mutable_shape() =
2515 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
2516 .ToProto();
2517 recv_done_instr.set_channel_id(handle.handle());
2518 recv_done_instr.set_is_host_transfer(true);
2519 return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
2520 {recv});
2521 });
2522 }
2523
2524 XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
2525 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2526 HloInstructionProto instr;
2527 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2528 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
2529 *operand_shape, dimension));
2530 *instr.mutable_shape() = shape.ToProto();
2531 instr.add_dimensions(dimension);
2532 return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
2533 {operand});
2534 });
2535 }
2536
2537 XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
2538 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2539 HloInstructionProto instr;
2540 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
2541 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
2542 *operand_shape, dimension));
2543 *instr.mutable_shape() = shape.ToProto();
2544 instr.add_dimensions(dimension);
2545 return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
2546 {operand, val});
2547 });
2548 }
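
// Hedged usage sketch (commented out): GetDimensionSize reads the (possibly
// dynamic) size of a dimension as a scalar, and SetDimensionSize marks a
// dimension of an operand as having that runtime size. Names are hypothetical.
//
//   XlaBuilder b("dimension_size_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "x");
//   XlaOp n = GetDimensionSize(x, /*dimension=*/0);  // Scalar size of dim 0.
//   SetDimensionSize(x, n, /*dimension=*/0);         // x with dim 0 dynamic.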
2549
2550 StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
2551 TF_RETURN_IF_ERROR(first_error_);
2552
2553 // Verify that the handle is valid.
2554 TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
2555
2556 bool is_constant = true;
2557 absl::flat_hash_set<int64> visited;
2558 IsConstantVisitor(operand.handle(), &visited, &is_constant);
2559 return is_constant;
2560 }
2561
2562 StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
2563 XlaOp root_op, bool dynamic_dimension_is_minus_one) {
2564 TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
2565 if (!is_constant) {
2566 auto op_status = LookUpInstruction(root_op);
2567 string op_string =
2568 op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
2569 return InvalidArgument(
2570 "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
2571 " op requested for constant subgraph: %s\n\n"
2572 "This is an internal error that typically happens when the XLA user "
2573 "(e.g. TensorFlow) is attempting to determine a value that must be a "
2574 "compile-time constant (e.g. an array dimension) but it is not capable "
2575 "of being evaluated at XLA compile time.\n\n"
2576 "Please file a usability bug with the framework being used (e.g. "
2577 "TensorFlow).",
2578 op_string);
2579 }
2580
2581 TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
2582 LookUpInstruction(root_op));
2583
2584 HloComputationProto entry;
2585 SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
2586 GetNextId());
2587 entry.set_root_id(root->id());
2588 ProgramShapeProto* program_shape = entry.mutable_program_shape();
2589 *program_shape->mutable_result() = root->shape();
2590
2591 // We use std::set to keep the instruction ids in ascending order (which is
2592 // also a valid dependency order). The related ops will be added to the
2593 // subgraph in the same order.
2594 std::set<int64> related_ops;
2595 absl::flat_hash_set<int64> related_calls; // Related computations.
2596 std::queue<int64> worklist;
2597 worklist.push(root->id());
2598 related_ops.insert(root->id());
2599 while (!worklist.empty()) {
2600 int64 handle = worklist.front();
2601 worklist.pop();
2602 TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
2603 LookUpInstructionByHandle(handle));
2604
2605 if (instr_proto->opcode() ==
2606 HloOpcodeString(HloOpcode::kGetDimensionSize)) {
2607       // At this point, BuildConstantSubGraph should never encounter a
2608       // GetDimensionSize with a dynamic dimension; the IsConstant check at the
2609       // beginning of this function would have failed otherwise.
2610 //
2611 // Replace GetDimensionSize with a Constant representing the static bound
2612 // of the shape.
2613 int64 dimension = instr_proto->dimensions(0);
2614 int64 operand_handle = instr_proto->operand_ids(0);
2615 TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
2616 LookUpInstructionByHandle(operand_handle));
2617
2618 int32 constant_dimension_size = -1;
2619 if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
2620 dynamic_dimension_is_minus_one)) {
2621 constant_dimension_size =
2622 static_cast<int32>(operand_proto->shape().dimensions(dimension));
2623 }
2624
2625 Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
2626
2627 HloInstructionProto const_instr;
2628 *const_instr.mutable_shape() = literal.shape().ToProto();
2629 *const_instr.mutable_literal() = literal.ToProto();
2630 *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
2631
2632 const_instr.set_id(handle);
2633 *const_instr.mutable_name() =
2634 GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
2635 *entry.add_instructions() =
2636 const_instr; // Add to the result constant graph.
2637 } else {
2638 for (int64 id : instr_proto->operand_ids()) {
2639 if (related_ops.insert(id).second) {
2640 worklist.push(id);
2641 }
2642 }
2643 for (int64 called_id : instr_proto->called_computation_ids()) {
2644 related_calls.insert(called_id);
2645 }
2646 }
2647 }
2648
2649 // Add related ops to the computation.
2650 for (int64 id : related_ops) {
2651 TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
2652 LookUpInstructionByHandle(id));
2653
2654 if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
2655 continue;
2656 }
2657 auto* instr = entry.add_instructions();
2658
2659 *instr = *instr_src;
2660     // Ensures that the instruction names are unique within the graph.
2661 const string& new_name =
2662 StrCat(instr->name(), ".", entry.id(), ".", instr->id());
2663 instr->set_name(new_name);
2664 }
2665
2666 XlaComputation computation(entry.id());
2667 HloModuleProto* module = computation.mutable_proto();
2668 module->set_name(entry.name());
2669 module->set_id(entry.id());
2670 module->set_entry_computation_name(entry.name());
2671 module->set_entry_computation_id(entry.id());
2672 *module->mutable_host_program_shape() = *program_shape;
2673 for (auto& e : embedded_) {
2674 if (related_calls.find(e.second.id()) != related_calls.end()) {
2675 *module->add_computations() = e.second;
2676 }
2677 }
2678 *module->add_computations() = std::move(entry);
2679
2680 return std::move(computation);
2681 }
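// Illustrative use of IsConstant and BuildConstantSubGraph from client code
// (comment-only sketch; the builder name and values are assumptions):
//
//   XlaBuilder b("constant_subgraph_example");
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "p");
//   XlaOp c = Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 3.0f));
//   b.IsConstant(p).ValueOrDie();  // false: depends on a parameter.
//   b.IsConstant(c).ValueOrDie();  // true: evaluable at compile time.
//   StatusOr<XlaComputation> subgraph =
//       b.BuildConstantSubGraph(c, /*dynamic_dimension_is_minus_one=*/false);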
2682
2683 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
2684 const string& computation_name) {
2685 auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
2686 sub_builder->parent_builder_ = this;
2687 sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
2688 return sub_builder;
2689 }
2690
2691 /* static */ ConvolutionDimensionNumbers
2692 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
2693 ConvolutionDimensionNumbers dimension_numbers;
2694 dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
2695 dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
2696 dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
2697 dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
2698 dimension_numbers.set_kernel_output_feature_dimension(
2699 kConvKernelOutputDimension);
2700 dimension_numbers.set_kernel_input_feature_dimension(
2701 kConvKernelInputDimension);
2702 for (int i = 0; i < num_spatial_dims; ++i) {
2703 dimension_numbers.add_input_spatial_dimensions(i + 2);
2704 dimension_numbers.add_kernel_spatial_dimensions(i + 2);
2705 dimension_numbers.add_output_spatial_dimensions(i + 2);
2706 }
2707 return dimension_numbers;
2708 }
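// If the kConv* constants keep their conventional values (batch and kernel
// output-feature at 0, feature and kernel input-feature at 1, spatial
// dimensions starting at 2), the defaults above correspond to NCHW activations
// and an OIHW kernel for num_spatial_dims == 2.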
2709
2710 /* static */ Status XlaBuilder::Validate(
2711 const ConvolutionDimensionNumbers& dnum) {
2712 if (dnum.input_spatial_dimensions_size() < 2) {
2713     return FailedPrecondition("input spatial dimension < 2: %d",
2714 dnum.input_spatial_dimensions_size());
2715 }
2716 if (dnum.kernel_spatial_dimensions_size() < 2) {
2717     return FailedPrecondition("kernel spatial dimension < 2: %d",
2718 dnum.kernel_spatial_dimensions_size());
2719 }
2720 if (dnum.output_spatial_dimensions_size() < 2) {
2721     return FailedPrecondition("output spatial dimension < 2: %d",
2722 dnum.output_spatial_dimensions_size());
2723 }
2724
2725 if (std::set<int64>(
2726 {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
2727 dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
2728 .size() != 4) {
2729 return FailedPrecondition(
2730 "dimension numbers for the input are not unique: (%d, %d, %d, "
2731 "%d)",
2732 dnum.input_batch_dimension(), dnum.input_feature_dimension(),
2733 dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
2734 }
2735 if (std::set<int64>({dnum.kernel_output_feature_dimension(),
2736 dnum.kernel_input_feature_dimension(),
2737 dnum.kernel_spatial_dimensions(0),
2738 dnum.kernel_spatial_dimensions(1)})
2739 .size() != 4) {
2740 return FailedPrecondition(
2741 "dimension numbers for the weight are not unique: (%d, %d, %d, "
2742 "%d)",
2743 dnum.kernel_output_feature_dimension(),
2744 dnum.kernel_input_feature_dimension(),
2745 dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
2746 }
2747 if (std::set<int64>({dnum.output_batch_dimension(),
2748 dnum.output_feature_dimension(),
2749 dnum.output_spatial_dimensions(0),
2750 dnum.output_spatial_dimensions(1)})
2751 .size() != 4) {
2752 return FailedPrecondition(
2753 "dimension numbers for the output are not unique: (%d, %d, %d, "
2754 "%d)",
2755 dnum.output_batch_dimension(), dnum.output_feature_dimension(),
2756 dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
2757 }
2758 return Status::OK();
2759 }
2760
2761 StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
2762 HloOpcode opcode,
2763 absl::Span<const XlaOp> operands) {
2764 TF_RETURN_IF_ERROR(first_error_);
2765
2766 const int64 handle = GetNextId();
2767 instr.set_id(handle);
2768 instr.set_opcode(HloOpcodeString(opcode));
2769 if (instr.name().empty()) {
2770 instr.set_name(instr.opcode());
2771 }
2772 for (const auto& operand : operands) {
2773 if (operand.builder_ == nullptr) {
2774 return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
2775 }
2776 if (operand.builder_ != this) {
2777 return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
2778 operand.builder_->name(), this->name());
2779 }
2780 instr.add_operand_ids(operand.handle());
2781 }
2782
2783 *instr.mutable_metadata() = metadata_;
2784 if (sharding_) {
2785 *instr.mutable_sharding() = *sharding_;
2786 }
2787 *instr.mutable_frontend_attributes() = frontend_attributes_;
2788
2789 handle_to_index_[handle] = instructions_.size();
2790 instructions_.push_back(std::move(instr));
2791 instruction_shapes_.push_back(
2792 absl::make_unique<Shape>(instructions_.back().shape()));
2793
2794 XlaOp op(handle, this);
2795 return op;
2796 }
2797
2798 void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
2799 HloInstructionProto* instr) {
2800 absl::flat_hash_map<int64, int64> remapped_ids;
2801 std::vector<HloComputationProto> imported_computations;
2802 imported_computations.reserve(computation.proto().computations_size());
2803   // First, import the computations, remapping their IDs and capturing the
2804 // old->new mappings in remapped_ids.
2805 for (const HloComputationProto& e : computation.proto().computations()) {
2806 HloComputationProto new_computation(e);
2807 int64 computation_id = GetNextId();
2808 remapped_ids[new_computation.id()] = computation_id;
2809 SetProtoIdAndName(&new_computation,
2810 GetBaseName(new_computation.name(), kNameSeparator),
2811 kNameSeparator, computation_id);
2812 for (auto& instruction : *new_computation.mutable_instructions()) {
2813 int64 instruction_id = GetNextId();
2814 remapped_ids[instruction.id()] = instruction_id;
2815 SetProtoIdAndName(&instruction,
2816 GetBaseName(instruction.name(), kNameSeparator),
2817 kNameSeparator, instruction_id);
2818 }
2819 new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));
2820
2821 imported_computations.push_back(std::move(new_computation));
2822 }
2823   // Once we have imported all the computations and captured all the ID
2824   // mappings, we go back and fix up the IDs in the imported computations.
2825 instr->add_called_computation_ids(
2826 remapped_ids.at(computation.proto().entry_computation_id()));
2827 for (auto& imported_computation : imported_computations) {
2828 for (auto& instruction : *imported_computation.mutable_instructions()) {
2829 for (auto& operand_id : *instruction.mutable_operand_ids()) {
2830 operand_id = remapped_ids.at(operand_id);
2831 }
2832 for (auto& control_predecessor_id :
2833 *instruction.mutable_control_predecessor_ids()) {
2834 control_predecessor_id = remapped_ids.at(control_predecessor_id);
2835 }
2836 for (auto& called_computation_id :
2837 *instruction.mutable_called_computation_ids()) {
2838 called_computation_id = remapped_ids.at(called_computation_id);
2839 }
2840 }
2841
2842 int64 computation_id = imported_computation.id();
2843 embedded_.insert({computation_id, std::move(imported_computation)});
2844 }
2845 }
2846
2847 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
2848 const XlaOp op) const {
2849 TF_RETURN_IF_ERROR(first_error_);
2850 return LookUpInstructionInternal<const HloInstructionProto*>(
2851 handle_to_index_, instructions_, op.builder_, this, op.handle());
2852 }
2853
2854 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
2855 int64 handle) const {
2856 return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
2857 handle_to_index_, instructions_, handle);
2858 }
2859
2860 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
2861 const XlaOp op) {
2862 TF_RETURN_IF_ERROR(first_error_);
2863 return LookUpInstructionInternal<HloInstructionProto*>(
2864 handle_to_index_, instructions_, op.builder_, this, op.handle());
2865 }
2866
2867 StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
2868 int64 handle) {
2869 return LookUpInstructionByHandleInternal<HloInstructionProto*>(
2870 handle_to_index_, instructions_, handle);
2871 }
2872
2873 // Enqueues a "retrieve parameter value" instruction for a parameter that was
2874 // passed to the computation.
2875 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
2876 const string& name) {
2877 std::vector<bool> empty_bools;
2878 return Parameter(builder, parameter_number, shape, name, empty_bools);
2879 }
2880
2881 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
2882 const string& name,
2883 const std::vector<bool>& replicated_at_leaf_buffers) {
2884 return builder->Parameter(parameter_number, shape, name,
2885 replicated_at_leaf_buffers);
2886 }
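// Illustrative usage (comment-only sketch; the builder name, shapes, and
// parameter names are assumptions):
//
//   XlaBuilder b("axpy");
//   XlaOp alpha = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "alpha");
//   XlaOp x = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1024}), "x");
//   XlaOp y = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {1024}), "y");
//   Add(Mul(Broadcast(alpha, {1024}), x), y);
//   StatusOr<XlaComputation> axpy = b.Build();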
2887
2888 // Enqueues a constant with the value of the given literal onto the
2889 // computation.
2890 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
2891 return builder->ConstantLiteral(literal);
2892 }
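// For example (comment-only sketch, assuming an XlaBuilder b):
//   XlaOp c0 = ConstantLiteral(&b, LiteralUtil::CreateR0<float>(1.0f));
//   XlaOp c1 = ConstantLiteral(&b, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));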
2893
2894 XlaOp Broadcast(const XlaOp operand, absl::Span<const int64> broadcast_sizes) {
2895 return operand.builder()->Broadcast(operand, broadcast_sizes);
2896 }
2897
2898 XlaOp BroadcastInDim(const XlaOp operand,
2899 const absl::Span<const int64> out_dim_size,
2900 const absl::Span<const int64> broadcast_dimensions) {
2901 return operand.builder()->BroadcastInDim(operand, out_dim_size,
2902 broadcast_dimensions);
2903 }
2904
2905 XlaOp Copy(const XlaOp operand) {
2906 return operand.builder()->UnaryOp(HloOpcode::kCopy, operand);
2907 }
2908
2909 XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
2910 const PaddingConfig& padding_config) {
2911 return operand.builder()->Pad(operand, padding_value, padding_config);
2912 }
2913
2914 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> dimensions,
2915 absl::Span<const int64> new_sizes) {
2916 return operand.builder()->Reshape(operand, dimensions, new_sizes);
2917 }
2918
2919 XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
2920 return operand.builder()->Reshape(operand, new_sizes);
2921 }
2922
2923 XlaOp ReshapeWithInferredDimension(XlaOp operand,
2924 absl::Span<const int64> new_sizes,
2925 int64 inferred_dimension) {
2926 return operand.builder()->Reshape(operand, new_sizes, inferred_dimension);
2927 }
2928
2929 XlaOp Collapse(const XlaOp operand, absl::Span<const int64> dimensions) {
2930 return operand.builder()->Collapse(operand, dimensions);
2931 }
2932
2933 XlaOp Slice(const XlaOp operand, absl::Span<const int64> start_indices,
2934 absl::Span<const int64> limit_indices,
2935 absl::Span<const int64> strides) {
2936 return operand.builder()->Slice(operand, start_indices, limit_indices,
2937 strides);
2938 }
2939
2940 XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index,
2941 int64 stride, int64 dimno) {
2942 return operand.builder()->SliceInDim(operand, start_index, limit_index,
2943 stride, dimno);
2944 }
2945
2946 XlaOp DynamicSlice(const XlaOp operand, const XlaOp start_indices,
2947 absl::Span<const int64> slice_sizes) {
2948 return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
2949 }
2950 XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
2951 absl::Span<const int64> slice_sizes) {
2952 return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
2953 }
2954
2955 XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
2956 const XlaOp start_indices) {
2957 return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
2958 }
2959
2960 XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
2961 absl::Span<const XlaOp> start_indices) {
2962 return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
2963 }
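// Illustrative usage of the dynamic slice ops above (comment-only sketch; the
// builder, shapes, and indices are assumptions):
//
//   XlaOp t = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 10}), "t");
//   XlaOp i = ConstantR0<int32>(&b, 2);
//   XlaOp j = ConstantR0<int32>(&b, 4);
//   XlaOp window = DynamicSlice(t, {i, j}, /*slice_sizes=*/{3, 3});  // F32[3,3]
//   XlaOp updated = DynamicUpdateSlice(t, window, {i, j});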
2964
2965 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
2966 int64 dimension) {
2967 return builder->ConcatInDim(operands, dimension);
2968 }
2969
2970 void Trace(const string& tag, const XlaOp operand) {
2971 return operand.builder()->Trace(tag, operand);
2972 }
2973
2974 XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
2975 return pred.builder()->Select(pred, on_true, on_false);
2976 }
2977
2978 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
2979 return builder->Tuple(elements);
2980 }
2981
2982 XlaOp GetTupleElement(const XlaOp tuple_data, int64 index) {
2983 return tuple_data.builder()->GetTupleElement(tuple_data, index);
2984 }
2985
2986 XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
2987 absl::Span<const int64> broadcast_dimensions) {
2988 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
2989 }
2990
2991 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
2992 absl::Span<const int64> broadcast_dimensions) {
2993 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
2994 }
2995
2996 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
2997 absl::Span<const int64> broadcast_dimensions) {
2998 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
2999 }
3000
3001 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
3002 absl::Span<const int64> broadcast_dimensions) {
3003 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
3004 }
3005
3006 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
3007 absl::Span<const int64> broadcast_dimensions) {
3008 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
3009 }
3010
3011 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
3012 absl::Span<const int64> broadcast_dimensions) {
3013 return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
3014 }
3015
3016 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
3017 absl::Span<const int64> broadcast_dimensions,
3018 ComparisonDirection direction) {
3019 return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
3020 broadcast_dimensions, direction);
3021 }
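// Illustrative usage (comment-only sketch; the builder and shapes are
// assumptions): comparing a matrix against a per-row vector by lining up the
// vector's dimension 0 with the matrix's dimension 0.
//
//   XlaOp m = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 4}), "m");
//   XlaOp v = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3}), "v");
//   XlaOp mask = Gt(m, v, /*broadcast_dimensions=*/{0});  // PRED[3, 4]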
3022
3023 XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
3024 const PrecisionConfig* precision_config) {
3025 return lhs.builder()->Dot(lhs, rhs, precision_config);
3026 }
3027
3028 XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
3029 const DotDimensionNumbers& dimension_numbers,
3030 const PrecisionConfig* precision_config) {
3031 return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
3032 precision_config);
3033 }
3034
3035 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
3036 absl::Span<const int64> window_strides, Padding padding,
3037 int64 feature_group_count, int64 batch_group_count,
3038 const PrecisionConfig* precision_config) {
3039 return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
3040 feature_group_count, batch_group_count,
3041 precision_config);
3042 }
3043
3044 XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs,
3045 absl::Span<const int64> window_strides,
3046 absl::Span<const std::pair<int64, int64>> padding,
3047 int64 feature_group_count, int64 batch_group_count,
3048 const PrecisionConfig* precision_config) {
3049 return lhs.builder()->ConvWithGeneralPadding(
3050 lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
3051 precision_config);
3052 }
3053
3054 XlaOp ConvWithGeneralDimensions(
3055 const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
3056 Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
3057 int64 feature_group_count, int64 batch_group_count,
3058 const PrecisionConfig* precision_config) {
3059 return lhs.builder()->ConvWithGeneralDimensions(
3060 lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
3061 batch_group_count, precision_config);
3062 }
3063
3064 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
3065 absl::Span<const int64> window_strides,
3066 absl::Span<const std::pair<int64, int64>> padding,
3067 const ConvolutionDimensionNumbers& dimension_numbers,
3068 int64 feature_group_count, int64 batch_group_count,
3069 const PrecisionConfig* precision_config) {
3070 return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
3071 dimension_numbers, feature_group_count,
3072 batch_group_count, precision_config);
3073 }
3074
3075 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
3076 absl::Span<const int64> window_strides,
3077 absl::Span<const std::pair<int64, int64>> padding,
3078 absl::Span<const int64> lhs_dilation,
3079 absl::Span<const int64> rhs_dilation,
3080 const ConvolutionDimensionNumbers& dimension_numbers,
3081 int64 feature_group_count, int64 batch_group_count,
3082 const PrecisionConfig* precision_config) {
3083 return lhs.builder()->ConvGeneralDilated(
3084 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
3085 dimension_numbers, feature_group_count, batch_group_count,
3086 precision_config);
3087 }
3088
3089 XlaOp Fft(const XlaOp operand, FftType fft_type,
3090 absl::Span<const int64> fft_length) {
3091 return operand.builder()->Fft(operand, fft_type, fft_length);
3092 }
3093
3094 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
3095 bool unit_diagonal,
3096 TriangularSolveOptions::Transpose transpose_a) {
3097 XlaBuilder* builder = a.builder();
3098 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3099 HloInstructionProto instr;
3100 TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
3101 TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b));
3102 xla::TriangularSolveOptions& options =
3103 *instr.mutable_triangular_solve_options();
3104 options.set_left_side(left_side);
3105 options.set_lower(lower);
3106 options.set_unit_diagonal(unit_diagonal);
3107 options.set_transpose_a(transpose_a);
3108 TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
3109 *a_shape, *b_shape, options));
3110 *instr.mutable_shape() = shape.ToProto();
3111
3112 return builder->AddInstruction(std::move(instr),
3113 HloOpcode::kTriangularSolve, {a, b});
3114 });
3115 }
3116
3117 XlaOp Cholesky(XlaOp a, bool lower) {
3118 XlaBuilder* builder = a.builder();
3119 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3120 HloInstructionProto instr;
3121 TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a));
3122 xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
3123 options.set_lower(lower);
3124 TF_ASSIGN_OR_RETURN(Shape shape,
3125 ShapeInference::InferCholeskyShape(*a_shape));
3126 *instr.mutable_shape() = shape.ToProto();
3127
3128 return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
3129 });
3130 }
3131
3132 XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
3133 return builder->Infeed(shape, config);
3134 }
3135
3136 void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
3137 const string& outfeed_config) {
3138 return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
3139 }
3140
3141 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
3142 absl::Span<const XlaOp> operands) {
3143 return builder->Call(computation, operands);
3144 }
3145
3146 XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
3147 absl::Span<const XlaOp> operands, const Shape& shape,
3148 const string& opaque) {
3149 return builder->CustomCall(call_target_name, operands, shape, opaque,
3150 /*operand_shapes_with_layout=*/absl::nullopt);
3151 }
3152
3153 XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
3154 absl::Span<const XlaOp> operands, const Shape& shape,
3155 absl::Span<const Shape> operand_shapes_with_layout,
3156 const string& opaque) {
3157 return builder->CustomCall(call_target_name, operands, shape, opaque,
3158 operand_shapes_with_layout);
3159 }
3160
3161 XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
3162 absl::Span<const int64> broadcast_dimensions) {
3163 return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
3164 broadcast_dimensions);
3165 }
3166
3167 XlaOp Conj(const XlaOp operand) {
3168 return Complex(Real(operand), Neg(Imag(operand)));
3169 }
3170
3171 XlaOp Add(const XlaOp lhs, const XlaOp rhs,
3172 absl::Span<const int64> broadcast_dimensions) {
3173 return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
3174 broadcast_dimensions);
3175 }
3176
3177 XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
3178 absl::Span<const int64> broadcast_dimensions) {
3179 return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
3180 broadcast_dimensions);
3181 }
3182
3183 XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
3184 absl::Span<const int64> broadcast_dimensions) {
3185 return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
3186 broadcast_dimensions);
3187 }
3188
3189 XlaOp Div(const XlaOp lhs, const XlaOp rhs,
3190 absl::Span<const int64> broadcast_dimensions) {
3191 return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
3192 broadcast_dimensions);
3193 }
3194
3195 XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
3196 absl::Span<const int64> broadcast_dimensions) {
3197 return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
3198 broadcast_dimensions);
3199 }
3200
3201 XlaOp Max(const XlaOp lhs, const XlaOp rhs,
3202 absl::Span<const int64> broadcast_dimensions) {
3203 return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
3204 broadcast_dimensions);
3205 }
3206
3207 XlaOp Min(const XlaOp lhs, const XlaOp rhs,
3208 absl::Span<const int64> broadcast_dimensions) {
3209 return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
3210 broadcast_dimensions);
3211 }
3212
3213 XlaOp And(const XlaOp lhs, const XlaOp rhs,
3214 absl::Span<const int64> broadcast_dimensions) {
3215 return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
3216 broadcast_dimensions);
3217 }
3218
3219 XlaOp Or(const XlaOp lhs, const XlaOp rhs,
3220 absl::Span<const int64> broadcast_dimensions) {
3221 return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
3222 broadcast_dimensions);
3223 }
3224
3225 XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
3226 absl::Span<const int64> broadcast_dimensions) {
3227 return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
3228 broadcast_dimensions);
3229 }
3230
3231 XlaOp Not(const XlaOp operand) {
3232 return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
3233 }
3234
3235 XlaOp PopulationCount(const XlaOp operand) {
3236 return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
3237 }
3238
3239 XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
3240 absl::Span<const int64> broadcast_dimensions) {
3241 return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
3242 broadcast_dimensions);
3243 }
3244
3245 XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
3246 absl::Span<const int64> broadcast_dimensions) {
3247 return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
3248 broadcast_dimensions);
3249 }
3250
3251 XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
3252 absl::Span<const int64> broadcast_dimensions) {
3253 return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
3254 broadcast_dimensions);
3255 }
3256
3257 XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
3258 const XlaComputation& computation,
3259 absl::Span<const int64> dimensions_to_reduce) {
3260 return operand.builder()->Reduce(operand, init_value, computation,
3261 dimensions_to_reduce);
3262 }
3263
3264 // Reduces several arrays simultaneously along the provided dimensions, given
3265 // "computation" as a reduction operator.
3266 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
3267 absl::Span<const XlaOp> init_values,
3268 const XlaComputation& computation,
3269 absl::Span<const int64> dimensions_to_reduce) {
3270 return builder->Reduce(operands, init_values, computation,
3271 dimensions_to_reduce);
3272 }
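// Illustrative usage (comment-only sketch; the builders, shapes, and names are
// assumptions): summing a 2D array over its second dimension with a
// scalar-add reduction computation.
//
//   XlaBuilder add_b("add");
//   Add(Parameter(&add_b, 0, ShapeUtil::MakeShape(F32, {}), "a"),
//       Parameter(&add_b, 1, ShapeUtil::MakeShape(F32, {}), "b"));
//   XlaComputation add = add_b.Build().ValueOrDie();
//
//   XlaOp m = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 16}), "m");
//   XlaOp row_sums = Reduce(m, ConstantR0<float>(&b, 0.0f), add,
//                           /*dimensions_to_reduce=*/{1});  // F32[8]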
3273
3274 XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
3275 const XlaComputation& computation) {
3276 return operand.builder()->ReduceAll(operand, init_value, computation);
3277 }
3278
3279 XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
3280 const XlaComputation& computation,
3281 absl::Span<const int64> window_dimensions,
3282 absl::Span<const int64> window_strides, Padding padding) {
3283 return operand.builder()->ReduceWindow(operand, init_value, computation,
3284 window_dimensions, window_strides,
3285 padding);
3286 }
3287
3288 XlaOp ReduceWindowWithGeneralPadding(
3289 const XlaOp operand, const XlaOp init_value,
3290 const XlaComputation& computation,
3291 absl::Span<const int64> window_dimensions,
3292 absl::Span<const int64> window_strides,
3293 absl::Span<const int64> base_dilations,
3294 absl::Span<const int64> window_dilations,
3295 absl::Span<const std::pair<int64, int64>> padding) {
3296 return operand.builder()->ReduceWindowWithGeneralPadding(
3297 operand, init_value, computation, window_dimensions, window_strides,
3298 base_dilations, window_dilations, padding);
3299 }
3300
3301 XlaOp CrossReplicaSum(const XlaOp operand,
3302 absl::Span<const ReplicaGroup> replica_groups) {
3303 return operand.builder()->CrossReplicaSum(operand, replica_groups);
3304 }
3305
3306 XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
3307 absl::Span<const ReplicaGroup> replica_groups,
3308 const absl::optional<ChannelHandle>& channel_id,
3309 const absl::optional<Shape>& shape_with_layout) {
3310 return operand.builder()->AllReduce(operand, computation, replica_groups,
3311 channel_id, shape_with_layout);
3312 }
3313
3314 XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
3315 int64 concat_dimension, int64 split_count,
3316 const std::vector<ReplicaGroup>& replica_groups) {
3317 return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
3318 split_count, replica_groups);
3319 }
3320
3321 XlaOp CollectivePermute(
3322 const XlaOp operand,
3323 const std::vector<std::pair<int64, int64>>& source_target_pairs) {
3324 return operand.builder()->CollectivePermute(operand, source_target_pairs);
3325 }
3326
3327 XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }
3328
3329 XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
3330 absl::Span<const int64> window_dimensions,
3331 absl::Span<const int64> window_strides, Padding padding,
3332 const XlaOp source, const XlaOp init_value,
3333 const XlaComputation& scatter) {
3334 return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
3335 window_strides, padding, source,
3336 init_value, scatter);
3337 }
3338
3339 XlaOp SelectAndScatterWithGeneralPadding(
3340 const XlaOp operand, const XlaComputation& select,
3341 absl::Span<const int64> window_dimensions,
3342 absl::Span<const int64> window_strides,
3343 absl::Span<const std::pair<int64, int64>> padding, const XlaOp source,
3344 const XlaOp init_value, const XlaComputation& scatter) {
3345 return operand.builder()->SelectAndScatterWithGeneralPadding(
3346 operand, select, window_dimensions, window_strides, padding, source,
3347 init_value, scatter);
3348 }
3349
3350 XlaOp Abs(const XlaOp operand) {
3351 return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
3352 }
3353
3354 XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
3355 absl::Span<const int64> broadcast_dimensions) {
3356 return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
3357 broadcast_dimensions);
3358 }
3359
3360 XlaOp Exp(const XlaOp operand) {
3361 return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
3362 }
3363 XlaOp Expm1(const XlaOp operand) {
3364 return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
3365 }
3366 XlaOp Floor(const XlaOp operand) {
3367 return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
3368 }
3369 XlaOp Ceil(const XlaOp operand) {
3370 return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
3371 }
3372 XlaOp Round(const XlaOp operand) {
3373 return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
3374 }
3375 XlaOp Log(const XlaOp operand) {
3376 return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
3377 }
3378 XlaOp Log1p(const XlaOp operand) {
3379 return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
3380 }
3381 XlaOp Sign(const XlaOp operand) {
3382 return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
3383 }
3384 XlaOp Clz(const XlaOp operand) {
3385 return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
3386 }
3387 XlaOp Cos(const XlaOp operand) {
3388 return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
3389 }
3390 XlaOp Sin(const XlaOp operand) {
3391 return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
3392 }
3393 XlaOp Tanh(const XlaOp operand) {
3394 return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
3395 }
3396 XlaOp Real(const XlaOp operand) {
3397 return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
3398 }
3399 XlaOp Imag(const XlaOp operand) {
3400 return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
3401 }
3402 XlaOp Sqrt(const XlaOp operand) {
3403 return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
3404 }
3405 XlaOp Rsqrt(const XlaOp operand) {
3406 return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
3407 }
3408
3409 XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
3410 absl::Span<const int64> broadcast_dimensions) {
3411 return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
3412 broadcast_dimensions);
3413 }
3414
3415 XlaOp IsFinite(const XlaOp operand) {
3416 return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
3417 }
3418
3419 XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
3420 return operand.builder()->ConvertElementType(operand, new_element_type);
3421 }
3422
3423 XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
3424 return operand.builder()->BitcastConvertType(operand, new_element_type);
3425 }
3426
3427 XlaOp Neg(const XlaOp operand) {
3428 return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
3429 }
3430
3431 XlaOp Transpose(const XlaOp operand, absl::Span<const int64> permutation) {
3432 return operand.builder()->Transpose(operand, permutation);
3433 }
3434
3435 XlaOp Rev(const XlaOp operand, absl::Span<const int64> dimensions) {
3436 return operand.builder()->Rev(operand, dimensions);
3437 }
3438
3439 XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
3440 int64 dimension, bool is_stable) {
3441 return operands[0].builder()->Sort(operands, comparator, dimension,
3442 is_stable);
3443 }
3444
3445 XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
3446 return min.builder()->Clamp(min, operand, max);
3447 }
3448
3449 XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
3450 const XlaComputation& computation, absl::Span<const int64> dimensions,
3451 absl::Span<const XlaOp> static_operands) {
3452 return builder->Map(operands, computation, dimensions, static_operands);
3453 }
3454
3455 XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
3456 return mu.builder()->RngNormal(mu, sigma, shape);
3457 }
3458
3459 XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
3460 return a.builder()->RngUniform(a, b, shape);
3461 }
3462
3463 XlaOp While(const XlaComputation& condition, const XlaComputation& body,
3464 const XlaOp init) {
3465 return init.builder()->While(condition, body, init);
3466 }
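// Illustrative usage (comment-only sketch; the builders, shapes, and bound are
// assumptions): a counting loop built from condition and body computations.
//
//   XlaBuilder cond_b("cond");
//   Lt(Parameter(&cond_b, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//      ConstantR0<int32>(&cond_b, 10));
//   XlaComputation cond = cond_b.Build().ValueOrDie();
//
//   XlaBuilder body_b("body");
//   Add(Parameter(&body_b, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//       ConstantR0<int32>(&body_b, 1));
//   XlaComputation body = body_b.Build().ValueOrDie();
//
//   XlaOp counter = While(cond, body, ConstantR0<int32>(&b, 0));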
3467
3468 XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
3469 const XlaComputation& true_computation,
3470 const XlaOp false_operand,
3471 const XlaComputation& false_computation) {
3472 return predicate.builder()->Conditional(predicate, true_operand,
3473 true_computation, false_operand,
3474 false_computation);
3475 }
3476
3477 XlaOp Conditional(const XlaOp branch_index,
3478 absl::Span<const XlaComputation* const> branch_computations,
3479 absl::Span<const XlaOp> branch_operands) {
3480 return branch_index.builder()->Conditional(branch_index, branch_computations,
3481 branch_operands);
3482 }
3483
3484 XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
3485 const int mantissa_bits) {
3486 return operand.builder()->ReducePrecision(operand, exponent_bits,
3487 mantissa_bits);
3488 }
3489
3490 XlaOp Gather(const XlaOp input, const XlaOp start_indices,
3491 const GatherDimensionNumbers& dimension_numbers,
3492 absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
3493 return input.builder()->Gather(input, start_indices, dimension_numbers,
3494 slice_sizes, indices_are_sorted);
3495 }
3496
3497 XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
3498 const XlaOp updates, const XlaComputation& update_computation,
3499 const ScatterDimensionNumbers& dimension_numbers,
3500 bool indices_are_sorted, bool unique_indices) {
3501 return input.builder()->Scatter(input, scatter_indices, updates,
3502 update_computation, dimension_numbers,
3503 indices_are_sorted, unique_indices);
3504 }
3505
3506 void Send(const XlaOp operand, const ChannelHandle& handle) {
3507 return operand.builder()->Send(operand, handle);
3508 }
3509
3510 XlaOp Recv(XlaBuilder* builder, const Shape& shape,
3511 const ChannelHandle& handle) {
3512 return builder->Recv(shape, handle);
3513 }
3514
3515 XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
3516 const ChannelHandle& handle) {
3517 return operand.builder()->SendWithToken(operand, token, handle);
3518 }
3519
3520 XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
3521 const ChannelHandle& handle) {
3522 return token.builder()->RecvWithToken(token, shape, handle);
3523 }
3524
3525 XlaOp SendToHost(const XlaOp operand, const XlaOp token,
3526 const Shape& shape_with_layout, const ChannelHandle& handle) {
3527 return operand.builder()->SendToHost(operand, token, shape_with_layout,
3528 handle);
3529 }
3530
3531 XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
3532 const ChannelHandle& handle) {
3533 return token.builder()->RecvFromHost(token, shape, handle);
3534 }
3535
3536 XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
3537 const string& config) {
3538 return token.builder()->InfeedWithToken(token, shape, config);
3539 }
3540
3541 XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
3542 const Shape& shape_with_layout,
3543 const string& outfeed_config) {
3544 return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
3545 outfeed_config);
3546 }
3547
3548 XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
3549
3550 XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
3551 return builder->AfterAll(tokens);
3552 }
3553
3554 XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
3555 const XlaOp offset, float epsilon,
3556 int64 feature_index) {
3557 return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
3558 feature_index);
3559 }
3560
3561 XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
3562 const XlaOp offset, const XlaOp mean,
3563 const XlaOp variance, float epsilon,
3564 int64 feature_index) {
3565 return operand.builder()->BatchNormInference(
3566 operand, scale, offset, mean, variance, epsilon, feature_index);
3567 }
3568
3569 XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
3570 const XlaOp batch_mean, const XlaOp batch_var,
3571 const XlaOp grad_output, float epsilon,
3572 int64 feature_index) {
3573 return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
3574 grad_output, epsilon, feature_index);
3575 }
3576
3577 XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
3578 return builder->Iota(type, size);
3579 }
3580
3581 XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
3582 return builder->Iota(shape, iota_dimension);
3583 }
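// For example (comment-only sketch, assuming an XlaBuilder b):
//   XlaOp v = Iota(&b, S32, /*size=*/5);  // S32[5]: {0, 1, 2, 3, 4}
//   XlaOp rows =
//       Iota(&b, ShapeUtil::MakeShape(S32, {3, 4}), /*iota_dimension=*/0);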
3584
3585 XlaOp GetDimensionSize(const XlaOp operand, int64 dimension) {
3586 return operand.builder()->GetDimensionSize(operand, dimension);
3587 }
3588
3589 XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
3590 return operand.builder()->SetDimensionSize(operand, val, dimension);
3591 }
3592
3593 } // namespace xla
3594