/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <cstdint>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stacktrace.h"

namespace xla {

class XlaBuilder;
class XlaOp;
class HloInstruction;

namespace internal {

struct XlaBuilderFriend {
  static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand,
                                  XlaOp token, const Shape& shape);

  static XlaOp BuildFusion(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           absl::string_view fusion_kind,
                           const XlaComputation& fused_computation);

  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
                            const Shape& shape);

  static XlaOp BuildPartitionId(XlaBuilder* builder, const Shape& shape);

  static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand,
                           const OpSharding entry, const OpSharding exit,
                           const Shape& shape);

  static XlaOp BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta,
                                         const Shape& shape);

  static HloInstructionProto* GetInstruction(XlaOp op);
  static HloInstructionProto* GetInstructionByHandle(XlaBuilder* builder,
                                                     int64_t handle);
};

}  // namespace internal

// This represents an instruction that has been enqueued using the XlaBuilder.
// It is used to pass to subsequent computations that depend upon the
// instruction as an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar(). Without this precondition,
  // if foo.builder() is null, the call to bar will segfault at some point
  // possibly deep in the callstack when we finally dereference `this`. The
  // precondition lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(XlaOp rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64_t handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64_t handle() const { return handle_; }

  friend class XlaBuilder;
  friend class ValueInference;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

  // < 0 means "invalid handle".
  int64_t handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};

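// Example (illustrative sketch, not part of this header): an XlaOp is a
// cheap value handle into a builder's instruction graph, so computations are
// composed by passing XlaOps around. The names below are hypothetical.
//
//   XlaBuilder b("example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "x");
//   XlaOp y = ConstantR0<float>(&b, 2.0f);
//   XlaOp z = Mul(x, y);  // Enqueues a multiply; z.builder() == &b.
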
// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(XlaOp x);
XlaOp operator+(XlaOp x, XlaOp y);
XlaOp operator-(XlaOp x, XlaOp y);
XlaOp operator*(XlaOp x, XlaOp y);
XlaOp operator/(XlaOp x, XlaOp y);
XlaOp operator%(XlaOp x, XlaOp y);

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(XlaOp x);
XlaOp operator&(XlaOp x, XlaOp y);
XlaOp operator|(XlaOp x, XlaOp y);
XlaOp operator^(XlaOp x, XlaOp y);
XlaOp operator<<(XlaOp x, XlaOp y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise
// performs a right logical shift.
XlaOp operator>>(XlaOp x, XlaOp y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because
// the semantics might be surprising, since their result types are usually
// 'bool'. Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.
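
// Example (illustrative sketch): the overloads above enqueue instructions
// rather than compute values, and comparisons are spelled as the free
// functions Eq, Ne, Lt, etc., which return a PRED-typed XlaOp. The names
// `x` and `y` are hypothetical.
//
//   XlaOp sum = x + y;        // Same as Add(x, y).
//   XlaOp mask = Lt(sum, y);  // PRED-typed XlaOp, not a C++ bool.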

// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  explicit XlaBuilder(const std::string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  virtual ~XlaBuilder();

  // Returns the computation name.
  const std::string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }

  // Swaps the passed op metadata with the ones currently set.
  //
  // Returns the old op metadata.
  OpMetadata SwapOpMetadata(OpMetadata metadata) {
    OpMetadata old_metadata = std::move(metadata_);
    metadata_ = std::move(metadata);
    return old_metadata;
  }

  // Similar to SetOpMetadata, but only sets the metadata for the next op.
  void SetOneShotOpMetadata(OpMetadata metadata) {
    one_shot_metadata_ = std::move(metadata);
  }

  // Clears the OpMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }
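
  // Example (illustrative sketch): attach debug metadata to a run of ops.
  // The field values and the `builder`/`logits` names are hypothetical.
  //
  //   OpMetadata m;
  //   m.set_op_type("Softmax");
  //   m.set_op_name("softmax/exp");
  //   builder.SetOpMetadata(m);   // Applied to all ops created from here...
  //   XlaOp e = Exp(logits);
  //   builder.ClearOpMetadata();  // ...until cleared here.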

  // Sets an OpSharding that will be attached to all instructions until
  // cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Sets the FrontendAttributes that will be added to all instructions until
  // cleared.
  //
  // FrontendAttributes are often applied to a series of XLA HLO instructions.
  // As a result they are set on the computation builder and all the
  // instructions generated via the computation builder will have the same
  // frontend attributes attached to them.
  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
    frontend_attributes_ = frontend_attributes;
  }

  // Swaps the passed FrontendAttributes with the ones currently set.
  //
  // Returns the old attributes.
  FrontendAttributes SwapFrontendAttributes(
      const FrontendAttributes& frontend_attributes) {
    FrontendAttributes old_attributes = std::move(frontend_attributes_);
    frontend_attributes_ = frontend_attributes;
    return old_attributes;
  }

  // Returns the FrontendAttributes that will be attached to all instructions.
  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Clears all the frontend attributes.
  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }

  // Clears the sharding. Ops will be sharded according to the default
  // placement policy.
  void ClearSharding() { sharding_ = std::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const std::optional<OpSharding>& sharding() const { return sharding_; }
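
  // Example (illustrative sketch): mark a run of ops as replicated across
  // devices, using the OpSharding proto from xla_data.proto. The names
  // `builder`, `x`, and `y` are hypothetical.
  //
  //   OpSharding replicated;
  //   replicated.set_type(OpSharding::REPLICATED);
  //   builder.SetSharding(replicated);  // Attached to ops created from here.
  //   XlaOp r = Add(x, y);
  //   builder.ClearSharding();          // Back to the default placement.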

  // Sets the builder to a mode where it will die immediately when an error
  // is encountered, rather than producing it in a deferred fashion when
  // Build() is called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64_t kConvBatchDimension = 0;
  static constexpr int64_t kConvFeatureDimension = 1;
  static constexpr int64_t kConvFirstSpatialDimension = 2;
  static constexpr int64_t kConvSecondSpatialDimension = 3;
  static constexpr int64_t kConvKernelOutputDimension = 0;
  static constexpr int64_t kConvKernelInputDimension = 1;
  static constexpr int64_t kConvKernelFirstSpatialDimension = 2;
  static constexpr int64_t kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand
  // {batch, feature, height, width} = {0, 1, 2, 3} and for the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);
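
  // Example (illustrative sketch): create NCHW/OIHW dimension numbers and
  // validate them before use.
  //
  //   ConvolutionDimensionNumbers dnums =
  //       XlaBuilder::CreateDefaultConvDimensionNumbers(
  //           /*num_spatial_dims=*/2);
  //   TF_RETURN_IF_ERROR(XlaBuilder::Validate(dnums));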

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(
      const std::string& computation_name);
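
  // Example (illustrative sketch): build a scalar Max reducer with a
  // sub-builder, for later use as a reduction computation. `builder` is
  // hypothetical.
  //
  //   std::unique_ptr<XlaBuilder> sub = builder.CreateSubBuilder("max");
  //   Shape scalar = ShapeUtil::MakeShape(F32, {});
  //   Max(Parameter(sub.get(), 0, scalar, "a"),
  //       Parameter(sub.get(), 1, scalar, "b"));
  //   XlaComputation max_comp = sub->BuildAndNoteError();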

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = false);
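
  // Example (illustrative sketch): finish a computation; any error recorded
  // while enqueueing ops surfaces here.
  //
  //   XlaBuilder b("add_one");
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   Add(p, ConstantR0<float>(&b, 1.0f));  // Last op becomes the root.
  //   StatusOr<XlaComputation> computation = b.Build();
  //   if (!computation.ok()) { /* handle the first recorded error */ }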

  // Builds the computation with the requested operations, or notes an error
  // in the parent XlaBuilder and returns an empty computation if building
  // failed. This function is intended to be used where the returned
  // XlaComputation is only used by the parent XlaBuilder and hence further
  // operations on the returned XlaComputation will simply error out if an
  // error occurred while building this computation. If the built computation
  // is to be used by an XlaBuilder other than the parent XlaBuilder then
  // Build() should be used instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that roots on the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(XlaOp op) const;

  // Returns the shape of the given op.
  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using
  // the given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //   Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value
  // can be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns
  // an invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
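
  // Example (illustrative sketch): the usual pattern for layering an op
  // helper on top of the builder, so failures are recorded rather than
  // thrown. `Double` is a hypothetical helper.
  //
  //   XlaOp Double(XlaBuilder* b, XlaOp x) {
  //     return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  //       TF_ASSIGN_OR_RETURN(const Shape* shape, b->GetShapePtr(x));
  //       (void)shape;  // e.g., validate the operand shape here.
  //       return x + x;
  //     });
  //   }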

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(XlaOp operand) const;
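
  // Example (illustrative sketch): a value built only from literals is a
  // compile-time constant and can be extracted via BuildConstantSubGraph.
  //
  //   XlaOp c = ConstantR0<int32_t>(&b, 6) * ConstantR0<int32_t>(&b, 7);
  //   TF_ASSIGN_OR_RETURN(bool is_const, b.IsConstant(c));  // true
  //   TF_ASSIGN_OR_RETURN(XlaComputation sub, b.BuildConstantSubGraph(c));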

  // Sets up binding which indicates that the `target_dim_num` in the subshape
  // `target_param_index` of parameter `target_param_num` is a dynamic
  // dimension and its real dynamic size is represented by
  // `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.")
  Status SetDynamicBinding(int64_t dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64_t target_param_num,
                           ShapeIndex target_param_index,
                           int64_t target_dim_num);

  // Adds a new input/output alias. Since the input/output shape information
  // is not available until the computation is built, any eventual error in
  // the arguments of this API will be detected only at computation Build()
  // time.
  //
  // Note: Except when 'must-alias' is true, the alias is assumed to be
  // 'may-alias', and only a buffer donated at runtime will be aliased with
  // the output. If a buffer is not donated at runtime, a copy will be
  // inserted by XLA to prevent buffer clobbering.
  void SetUpAlias(const ShapeIndex& output_index, int64_t param_number,
                  const ShapeIndex& param_index,
                  HloInputOutputAliasConfig::AliasKind kind =
                      HloInputOutputAliasConfig::AliasKind::kMayAlias) {
    input_output_aliases_.push_back(
        {output_index, param_number, param_index, kind});
  }

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64_t param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
    // Specifies if the alias is a must alias or may alias.
    HloInputOutputAliasConfig::AliasKind kind;
  };
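
  // Example (illustrative sketch): alias element {0} of the result tuple
  // with the whole of parameter 0, so a donated input buffer can be reused
  // in place.
  //
  //   b.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
  //                /*param_index=*/{});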

  // Looks up the HloInstruction and sets the frontend attribute "attribute"
  // to "value".
  //
  // If the attribute already existed then its value is updated.
  //
  // Note: the attribute is only added to the HloInstruction, not to the
  // builder.
  Status SetInstructionFrontendAttribute(XlaOp op, std::string attribute,
                                         std::string value);

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // Converts the op to string for ease of debugging.
  std::string OpToString(XlaOp op) const;

 private:
  void ToStringHelper(std::string* out, int ident, int64_t op_handle) const;

  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64_t root_id,
                                 bool remove_dynamic_dimensions);

  // Descriptions for the methods below can be found in the corresponding
  // public functions section in this file.

  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const std::string& name,
                  const std::vector<bool>& replicated_at_leaf_buffers);
  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const std::string& name) {
    std::vector<bool> empty_bools;
    return Parameter(parameter_number, shape, name, empty_bools);
  }

  virtual XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(XlaOp operand, absl::Span<const int64_t> broadcast_sizes);

  XlaOp BroadcastInDim(XlaOp operand,
                       const absl::Span<const int64_t> out_dim_size,
                       const absl::Span<const int64_t> broadcast_dimensions);

  XlaOp Pad(XlaOp operand, XlaOp padding_value,
            const PaddingConfig& padding_config);
  XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                 int64_t pad_lo, int64_t pad_hi);

  virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                                      XlaOp padding_value,
                                      const PaddingConfig& padding_config);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
                absl::Span<const int64_t> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64_t inferred_dimension = -1);

  XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                       absl::Span<const int64_t> new_size_bounds,
                       const std::vector<bool>& dims_are_dynamic);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64_t> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
              absl::Span<const int64_t> limit_indices,
              absl::Span<const int64_t> strides);
  virtual StatusOr<XlaOp> SliceInternal(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64_t> start_indices,
      absl::Span<const int64_t> limit_indices,
      absl::Span<const int64_t> strides);
  virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                           int64_t limit_index, int64_t stride, int64_t dimno);

  XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64_t> slice_sizes);
  virtual StatusOr<XlaOp> DynamicSliceInternal(
      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
      absl::Span<const int64_t> slice_sizes);

  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                           absl::Span<const XlaOp> start_indices);
  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
      const Shape& shape, XlaOp operand, XlaOp update,
      absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64_t dimension);
  virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
                                              absl::Span<const XlaOp> operands,
                                              int64_t dimension);

  XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);
  virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
                                        absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
                                                  XlaOp tuple_data,
                                                  int64_t index);

  XlaOp Dot(XlaOp lhs, XlaOp rhs,
            const PrecisionConfig* precision_config = nullptr,
            std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp DotGeneral(
      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp Conv(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      Padding padding, int64_t feature_group_count = 1,
      int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt,
      std::optional<std::vector<bool>> window_reversal = std::nullopt);

  XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  StatusOr<HloInstructionProto> DynamicConvInstruction(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config);

  XlaOp Fft(XlaOp operand, FftType fft_type,
            absl::Span<const int64_t> fft_length);
  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                                      FftType fft_type,
                                      absl::Span<const int64_t> fft_length);

  virtual StatusOr<XlaOp> TriangularSolveInternal(
      const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options);

  virtual StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
                                           bool lower);

  XlaOp Infeed(const Shape& shape, const std::string& config = "");
  XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                        const std::string& config);
  virtual StatusOr<XlaOp> InfeedWithTokenInternal(
      const Shape& infeed_instruction_shape, XlaOp token,
      const std::string& config);

  void Outfeed(XlaOp operand, const Shape& shape_with_layout,
               const std::string& outfeed_config);
  XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                         const Shape& shape_with_layout,
                         const std::string& outfeed_config);
  virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
      const std::string& outfeed_config);
  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const std::string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const std::string& opaque,
      std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, std::optional<Window> window,
      std::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  // Internal version of CustomCall without computation that doesn't do
  // op-specific error handling and expects arguments to be legal. The
  // CustomCall method above calls this method after error handling.
  virtual StatusOr<XlaOp> CustomCallInternal(
      const std::string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation* computation, const Shape& shape_with_layout,
      const std::string& opaque,
      std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, std::optional<Window> window,
      std::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  // TODO(b/239474321) Remove this overload as it has simply led to code
  // duplication.
  XlaOp CustomCall(
      const std::string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation& computation, const Shape& shape_with_layout,
      const std::string& opaque,
      std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);

  XlaOp OptimizationBarrier(XlaOp operand);

  XlaOp Reduce(XlaOp operand, XlaOp init_value,
               const XlaComputation& computation,
               absl::Span<const int64_t> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64_t> dimensions_to_reduce);

  virtual StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64_t> dimensions_to_reduce);

  XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64_t> window_dimensions,
                     absl::Span<const int64_t> window_strides,
                     Padding padding);

  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                     absl::Span<const XlaOp> init_values,
                     const XlaComputation& computation,
                     absl::Span<const int64_t> window_dimensions,
                     absl::Span<const int64_t> window_strides,
                     Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const int64_t> base_dilations,
      absl::Span<const int64_t> window_dilations,
      absl::Span<const std::pair<int64_t, int64_t>> padding);
  StatusOr<HloInstructionProto> ReduceWindowInternal(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const int64_t> base_dilations,
      absl::Span<const int64_t> window_dilations,
      absl::Span<const std::pair<int64_t, int64_t>> padding);
  virtual StatusOr<XlaOp> ReduceWindowInternal(
      const Shape& shape, XlaOp operand, XlaOp init_value,
      const XlaComputation& computation, Window window);
  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp AllGather(
      XlaOp operand, int64_t all_gather_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const std::optional<ChannelHandle>& channel_id = std::nullopt,
      const std::optional<Layout>& layout = std::nullopt,
      const std::optional<bool> use_global_device_ids = std::nullopt);

  XlaOp AllReduce(
      XlaOp operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const std::optional<ChannelHandle>& channel_id = std::nullopt,
      const std::optional<Shape>& shape_with_layout = std::nullopt,
      const std::optional<bool> use_global_device_ids = std::nullopt);

  XlaOp ReduceScatter(
      XlaOp operand, const XlaComputation& computation,
      int64_t scatter_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const std::optional<ChannelHandle>& channel_id = std::nullopt,
      const std::optional<Layout>& layout = std::nullopt,
      const std::optional<bool> use_global_device_ids = std::nullopt);

  XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
                 int64_t concat_dimension, int64_t split_count,
                 absl::Span<const ReplicaGroup> replica_groups,
                 const std::optional<Layout>& layout = std::nullopt);

  XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const std::optional<Layout>& layout);

  XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
                      int64_t concat_dimension, int64_t split_count,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const std::optional<Layout>& layout);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                         absl::Span<const int64_t> window_dimensions,
                         absl::Span<const int64_t> window_strides,
                         Padding padding, XlaOp source, XlaOp init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  StatusOr<HloInstructionProto> SelectAndScatterInternal(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension);

  XlaOp Iota(PrimitiveType type, int64_t size);

  XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
                                                     XlaOp operand);

  XlaOp Transpose(XlaOp operand, absl::Span<const int64_t> permutation);
  virtual StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64_t> permutation);

  XlaOp Rev(XlaOp operand, absl::Span<const int64_t> dimensions);
  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                                      absl::Span<const int64_t> dimensions);

  XlaOp Sort(absl::Span<const XlaOp> operands,
             const XlaComputation& comparator, int64_t dimension = -1,
             bool is_stable = false);
  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
                                       absl::Span<const XlaOp> operands,
                                       const XlaComputation& comparator,
                                       int64_t dimension, bool is_stable);

  XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

  XlaOp Map(absl::Span<const XlaOp> operands,
            const XlaComputation& computation,
            absl::Span<const int64_t> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

  XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

  XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                        const Shape& shape);
  // Internal variant for the op with the full result shape containing both
  // data and state shape as a tuple.
  virtual StatusOr<XlaOp> RngBitGeneratorInternal(
      const Shape& full_result_shape, RandomAlgorithm algorithm,
      XlaOp initial_state);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              XlaOp init);
  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                        const XlaComputation& condition,
                                        const XlaComputation& body,
                                        XlaOp init);

  XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                    const XlaComputation& true_computation,
                    XlaOp false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(XlaOp branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                        const int mantissa_bits);
  virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
                                                  XlaOp operand,
                                                  const int exponent_bits,
                                                  const int mantissa_bits);

  XlaOp Gather(XlaOp input, XlaOp start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64_t> slice_sizes,
               bool indices_are_sorted = false);

  virtual StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64_t> slice_sizes, bool indices_are_sorted);

  XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);
  XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
                absl::Span<const XlaOp> updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);

  virtual StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, absl::Span<const XlaOp> inputs,
      XlaOp scatter_indices, absl::Span<const XlaOp> updates,
      const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers,
      bool indices_are_sorted, bool unique_indices);

  void Send(XlaOp operand, const ChannelHandle& handle);
  XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

  XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                   const ChannelHandle& handle);

  XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                     const ChannelHandle& handle);

  virtual XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                          float epsilon, int64_t feature_index);

  XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                           XlaOp mean, XlaOp variance, float epsilon,
                           int64_t feature_index);

  XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                      XlaOp batch_var, XlaOp grad_output, float epsilon,
                      int64_t feature_index);

  XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

  XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

  virtual StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape,
                                                   XlaOp operand, XlaOp val,
                                                   int64_t dimension);

  XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);

  virtual StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                         HloOpcode opcode,
                                         absl::Span<const XlaOp> operands);
  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                 HloOpcode opcode) {
    return AddInstruction(std::move(instr), opcode, /*operands=*/{});
  }

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64_t handle) const;
  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(
      int64_t handle);

  // Internal helper method that does the building for an arbitrary unary op.
  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction
  // is only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                 absl::Span<const int64_t> broadcast_dimensions,
                 std::optional<ComparisonDirection> direction = std::nullopt,
                 std::optional<Comparison::Type> type = std::nullopt);
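
  // Example (illustrative sketch): broadcast_dimensions maps the lower-rank
  // operand's dimensions onto the higher-rank operand. Here a [3] vector is
  // added across dimension 1 of a [2, 3] matrix via the public Add();
  // `b` is a hypothetical builder.
  //
  //   XlaOp m = ConstantR2<float>(&b, {{1, 2, 3}, {4, 5, 6}});
  //   XlaOp v = ConstantR1<float>(&b, {10, 20, 30});
  //   XlaOp r = Add(m, v, /*broadcast_dimensions=*/{1});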

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction);

  // Internal helper method for binary op compare without broadcast
  // dimensions.
  virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                  ComparisonDirection direction,
                                  Comparison::Type type);

  // Internal helper method that does the building for an arbitrary binary op
  // with same-ranked operands that doesn't broadcast.
  virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                    XlaOp lhs, XlaOp rhs);

  // Internal helper method that does the building for an arbitrary ternary
  // op.
  XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  virtual StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                        absl::Span<const XlaOp> parameters,
                                        const Shape& shape);

  virtual StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64_t> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       XlaOp operand);

  // Internal helper method for creating a Reshape op with the already
  // inferred shape.
  virtual StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                          int64_t inferred_dimension);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64_t root_id) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false
  // iff a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64_t op_handle, int depth,
                         absl::flat_hash_set<int64_t>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  int64_t GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored
  // within the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  std::string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64_t next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  // Use a deque so pointers into this are stable, for example the return
  // value of LookUpInstructionByHandle().
  std::deque<HloInstructionProto> instructions_;
  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
  std::vector<std::unique_ptr<Shape>> instruction_shapes_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias()
  // API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ deque where
  // the instruction is held.
  absl::flat_hash_map<int64_t, int64_t> handle_to_index_;

  // Track imported instructions by their computation id and the position in
  // their computation's instruction list.
  struct ImportedInstruction {
    int64_t computation_id;
    int64_t instruction_index;
  };

  absl::flat_hash_map<int64_t, ImportedInstruction> handle_to_imported_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64_t, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64_t> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this
  // metadata throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // A temporary metadata that will only be applied to the next op created.
  std::optional<OpMetadata> one_shot_metadata_;

  // Sharding for this operator. This is structured as a "modal"-like
  // operation, in order to simplify client code, similar to metadata_.
  std::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  FrontendAttributes frontend_attributes_;

  friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                         const Shape& shape, const std::string& name,
                         const std::vector<bool>& replicated_at_leaf_buffers);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(XlaOp operand,
                         absl::Span<const int64_t> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      XlaOp operand, const absl::Span<const int64_t> out_dim_size,
      const absl::Span<const int64_t> broadcast_dimensions);

  friend XlaOp Copy(XlaOp operand);

  friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                        int64_t pad_lo, int64_t pad_hi);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
                       absl::Span<const int64_t> new_sizes);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes);

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                              absl::Span<const int64_t> new_size_bounds,
                              const std::vector<bool>& dims_are_dynamic);

  friend XlaOp ReshapeWithInferredDimension(
      XlaOp operand, absl::Span<const int64_t> new_sizes,
      int64_t inferred_dimension);

  friend XlaOp Collapse(XlaOp operand, absl::Span<const int64_t> dimensions);

  friend XlaOp Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
                     absl::Span<const int64_t> limit_indices,
                     absl::Span<const int64_t> strides);

  friend XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                          int64_t limit_index, int64_t stride, int64_t dimno);

  friend XlaOp DynamicSlice(XlaOp operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64_t> slice_sizes);

  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           int64_t dimension);

  friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64_t> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64_t> broadcast_dimensions,
                       ComparisonDirection direction,
                       Comparison::Type compare_type);
  friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
                   const PrecisionConfig* precision_config,
                   std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config,
                          std::optional<PrimitiveType> preferred_element_type);
  virtual StatusOr<XlaOp> DotGeneralInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs,
      const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config);
  friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64_t> window_strides, Padding padding,
                    int64_t feature_group_count, int64_t batch_group_count,
                    const PrecisionConfig* precision_config,
                    std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvKernelGrad(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type,
      std::optional<std::vector<bool>> window_reversal);

  friend XlaOp Fft(XlaOp operand, FftType fft_type,
                   absl::Span<const int64_t> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const std::string& config);
  friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
                      const std::string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(
      XlaBuilder* builder, const std::string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape,
      const std::string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);
  friend XlaOp CustomCallWithComputation(
      XlaBuilder* builder, const std::string& call_target_name,
      absl::Span<const XlaOp> operands, const XlaComputation& computation,
      const Shape& shape, const std::string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const std::string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout,
      const std::string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);
  friend XlaOp CustomCallWithConvDnums(
      XlaBuilder* builder, const std::string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape,
      absl::Span<const Shape> operand_shapes_with_layout,
      const std::string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, Window window,
      ConvolutionDimensionNumbers dnums, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);
  friend XlaOp OptimizationBarrier(XlaOp operand);
  friend XlaOp Complex(XlaOp real, XlaOp imag,
                       absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Conj(XlaOp operand);
  friend XlaOp Add(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Div(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Max(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Min(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp And(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Or(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Not(XlaOp operand);
  friend XlaOp PopulationCount(XlaOp operand);
  friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                         absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> broadcast_dimensions);
  friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64_t> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64_t> dimensions_to_reduce);
  friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64_t> window_dimensions,
                            absl::Span<const int64_t> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                            absl::Span<const XlaOp> init_values,
                            const XlaComputation& computation,
                            absl::Span<const int64_t> window_dimensions,
                            absl::Span<const int64_t> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1347 absl::Span<const int64_t> window_dimensions,
1348 absl::Span<const int64_t> window_strides,
1349 absl::Span<const int64_t> base_dilations,
1350 absl::Span<const int64_t> window_dilations,
1351 absl::Span<const std::pair<int64_t, int64_t>> padding);
1352 friend XlaOp ReduceWindowWithGeneralPadding(
1353 absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
1354 const XlaComputation& computation,
1355 absl::Span<const int64_t> window_dimensions,
1356 absl::Span<const int64_t> window_strides,
1357 absl::Span<const int64_t> base_dilations,
1358 absl::Span<const int64_t> window_dilations,
1359 absl::Span<const std::pair<int64_t, int64_t>> padding);
1360
1361 friend XlaOp CrossReplicaSum(XlaOp operand,
1362 absl::Span<const ReplicaGroup> replica_groups);
1363 friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
1364 int64_t shard_count,
1365 absl::Span<const ReplicaGroup> replica_groups,
1366 const std::optional<ChannelHandle>& channel_id,
1367 const std::optional<Layout>& layout,
1368 const std::optional<bool> use_global_device_ids);
1369 friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
1370 absl::Span<const ReplicaGroup> replica_groups,
1371 const std::optional<ChannelHandle>& channel_id,
1372 const std::optional<Shape>& shape_with_layout,
1373 const std::optional<bool> use_global_device_ids);
1374 friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
1375 int64_t scatter_dimension, int64_t shard_count,
1376 absl::Span<const ReplicaGroup> replica_groups,
1377 const std::optional<ChannelHandle>& channel_id,
1378 const std::optional<Layout>& layout,
1379 const std::optional<bool> use_global_device_ids);
1380
1381 friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
1382 int64_t concat_dimension, int64_t split_count,
1383 absl::Span<const ReplicaGroup> replica_groups,
1384 const std::optional<Layout>& layout);
1385 friend XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
1386 absl::Span<const ReplicaGroup> replica_groups,
1387 const std::optional<Layout>& layout);
1388 friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
1389 int64_t concat_dimension, int64_t split_count,
1390 absl::Span<const ReplicaGroup> replica_groups,
1391 const std::optional<Layout>& layout);
1392 friend XlaOp CollectivePermute(
1393 XlaOp operand,
1394 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
1395 friend XlaOp ReplicaId(XlaBuilder* builder);
1396 friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
1397 absl::Span<const int64_t> window_dimensions,
1398 absl::Span<const int64_t> window_strides,
1399 Padding padding, XlaOp source, XlaOp init_value,
1400 const XlaComputation& scatter);
1401 friend XlaOp SelectAndScatterWithGeneralPadding(
1402 XlaOp operand, const XlaComputation& select,
1403 absl::Span<const int64_t> window_dimensions,
1404 absl::Span<const int64_t> window_strides,
1405 absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
1406 XlaOp init_value, const XlaComputation& scatter);
1407 friend XlaOp Abs(XlaOp operand);
1408 friend XlaOp Atan2(XlaOp y, XlaOp x,
1409 absl::Span<const int64_t> broadcast_dimensions);
1410 friend XlaOp Exp(XlaOp operand);
1411 friend XlaOp Expm1(XlaOp operand);
1412 friend XlaOp Floor(XlaOp operand);
1413 friend XlaOp Ceil(XlaOp operand);
1414 friend XlaOp Round(XlaOp operand);
1415 friend XlaOp RoundNearestEven(XlaOp operand);
1416 friend XlaOp Log(XlaOp operand);
1417 friend XlaOp Log1p(XlaOp operand);
1418 friend XlaOp Logistic(XlaOp operand);
1419 friend XlaOp Sign(XlaOp operand);
1420 friend XlaOp Clz(XlaOp operand);
1421 friend XlaOp Cos(XlaOp operand);
1422 friend XlaOp Sin(XlaOp operand);
1423 friend XlaOp Tanh(XlaOp operand);
1424 friend XlaOp Real(XlaOp operand);
1425 friend XlaOp Imag(XlaOp operand);
1426 friend XlaOp Sqrt(XlaOp operand);
1427 friend XlaOp Rsqrt(XlaOp operand);
1428 friend XlaOp Cbrt(XlaOp operand);
1429 friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
1430 absl::Span<const int64_t> broadcast_dimensions);
1431 friend XlaOp IsFinite(XlaOp operand);
1432 friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
1433 int64_t iota_dimension);
1434 friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);
1435 friend XlaOp ConvertElementType(XlaOp operand,
1436 PrimitiveType new_element_type);
1437 friend XlaOp BitcastConvertType(XlaOp operand,
1438 PrimitiveType new_element_type);
1439 friend XlaOp Neg(XlaOp operand);
1440 friend XlaOp Transpose(XlaOp operand, absl::Span<const int64_t> permutation);
1441 friend XlaOp Rev(XlaOp operand, absl::Span<const int64_t> dimensions);
1442 friend XlaOp Sort(absl::Span<const XlaOp> operands,
1443 const XlaComputation& comparator, int64_t dimension,
1444 bool is_stable);
1445 friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
1446 friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1447 const XlaComputation& computation,
1448 absl::Span<const int64_t> dimensions,
1449 absl::Span<const XlaOp> static_operands);
1450 friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
1451 friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
1452 friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
1453 const Shape& shape);
1454 friend XlaOp While(const XlaComputation& condition,
1455 const XlaComputation& body, XlaOp init);
1456 friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
1457 const XlaComputation& true_computation,
1458 XlaOp false_operand,
1459 const XlaComputation& false_computation);
1460 friend XlaOp Conditional(
1461 XlaOp branch_index,
1462 absl::Span<const XlaComputation* const> branch_computations,
1463 absl::Span<const XlaOp> branch_operands);
1464 friend XlaOp ConditionalImpl(
1465 XlaOp branch_index,
1466 absl::Span<const XlaComputation* const> branch_computations,
1467 absl::Span<const XlaOp> branch_operands);
1468 friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
1469 const int mantissa_bits);
1470 friend XlaOp Gather(XlaOp input, XlaOp start_indices,
1471 const GatherDimensionNumbers& dimension_numbers,
1472 absl::Span<const int64_t> slice_sizes,
1473 bool indices_are_sorted);
1474 friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1475 const XlaComputation& update_computation,
1476 const ScatterDimensionNumbers& dimension_numbers,
1477 bool indices_are_sorted, bool unique_indices);
1478 friend XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
1479 absl::Span<const XlaOp> updates,
1480 const XlaComputation& update_computation,
1481 const ScatterDimensionNumbers& dimension_numbers,
1482 bool indices_are_sorted, bool unique_indices);
1483 friend void Send(XlaOp operand, const ChannelHandle& handle);
1484 friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1485 const ChannelHandle& handle);
1486 friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
1487 float epsilon, int64_t feature_index);
1488 friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
1489 XlaOp mean, XlaOp variance, float epsilon,
1490 int64_t feature_index);
1491 friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
1492 XlaOp batch_var, XlaOp grad_output, float epsilon,
1493 int64_t feature_index);
1494 friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
1495 const ChannelHandle& handle);
1496 friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
1497 const ChannelHandle& handle);
1498 friend XlaOp SendToHost(XlaOp operand, XlaOp token,
1499 const Shape& shape_with_layout,
1500 const ChannelHandle& handle);
1501 friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
1502 const ChannelHandle& handle);
1503 friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1504 const std::string& config);
1505 friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1506 const Shape& shape_with_layout,
1507 const std::string& outfeed_config);
1508 friend XlaOp CreateToken(XlaBuilder* builder);
1509 friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1510
1511 friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);
1512 friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);
1513 friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);
1514
1515 protected:
1516 // Returns OK status if the given op was built using this builder. Otherwise,
1517 // returns an error.
1518 Status CheckOpBuilder(XlaOp op) const;
1519
1520 private:
1521 XlaOp ConditionalImpl(
1522 XlaOp branch_index,
1523 absl::Span<const XlaComputation* const> branch_computations,
1524 absl::Span<const XlaOp> branch_operands);
1525
1526 XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
1527 int64_t concat_dimension, int64_t split_count,
1528 absl::Span<const ReplicaGroup> replica_groups);
1529
1530 // Creates an op with the given opcode and the output shape.
1531 virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
1532 absl::Span<const XlaOp> operands);
1533
1534 // Here, InstructionType is either const HloInstructionProto* or non-const
1535 // HloInstructionProto*.
1536 template <typename InstructionType>
1537 StatusOr<InstructionType> LookUpInstructionByHandleInternal(
1538 int64_t handle) const {
1539 auto it = handle_to_index_.find(handle);
1540 if (it == handle_to_index_.end()) {
1541 // Try to look up the instruction among the imported instructions.
1542 auto imported_it = handle_to_imported_index_.find(handle);
1543 if (imported_it != handle_to_imported_index_.end()) {
1544 ImportedInstruction imported = imported_it->second;
1545 return const_cast<InstructionType>(
1546 &embedded_.at(imported.computation_id)
1547 .instructions(imported.instruction_index));
1548 }
1549 return InvalidArgument("No XlaOp with handle %d", handle);
1550 }
1551 return const_cast<InstructionType>(&instructions_.at(it->second));
1552 }
1553
1554 // Here, InstructionType is either const HloInstructionProto* or non-const
1555 // HloInstructionProto*.
1556 //
1557 // TODO(hinsu): Return const pointer within StatusOr and use
1558 // absl::implicit_cast at callsites. This requires implicit_cast support in
1559 // stream_executor::port::StatusOr similar to absl::StatusOr.
1560 template <typename InstructionType>
1561 StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
1562 TF_RETURN_IF_ERROR(CheckOpBuilder(op));
1563 return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
1564 }
1565
1566 friend struct internal::XlaBuilderFriend;
1567
1568 friend class ValueInference;
1569 };
1570
1571 // RAII-style object: sets the current sharding assignment in builder on
1572 // construction, and sets back to the previous assignment on destruction.
1573 class XlaScopedShardingAssignment {
1574 public:
1575 XlaScopedShardingAssignment(xla::XlaBuilder* builder,
1576 std::optional<OpSharding> sharding)
1577 : builder_(builder), prev_sharding_(builder->sharding()) {
1578 SetSharding(sharding);
1579 }
1580
1581 XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
1582 XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
1583 delete;
1584
1585 ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
1586
1587 private:
1588 void SetSharding(const std::optional<OpSharding>& sharding) {
1589 if (sharding.has_value()) {
1590 builder_->SetSharding(sharding.value());
1591 } else {
1592 builder_->ClearSharding();
1593 }
1594 }
1595
1596 xla::XlaBuilder* const builder_;
1597 std::optional<OpSharding> prev_sharding_;
1598 };
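
// Example usage (an illustrative sketch, not part of the original API docs):
// ops built while the scope is alive carry the given sharding, and the
// previous sharding (if any) is restored when the scope exits.
//
//   XlaBuilder b("sharded_computation");
//   OpSharding replicated;
//   replicated.set_type(OpSharding::REPLICATED);
//   {
//     XlaScopedShardingAssignment scope(&b, replicated);
//     XlaOp x = ConstantR0<float>(&b, 1.0f);  // Built with the sharding set.
//   }  // The previous sharding is restored here.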
1599
1600 // RAII-style object: saves the current builder's frontend attributes on
1601 // construction and merges them with the newly-provided ones; restores the
1602 // original attributes on destruction.
1603 class XlaScopedFrontendAttributesAssignment {
1604 public:
1605 XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
1606 FrontendAttributes attributes)
1607 : builder_(builder) {
1608 saved_ = builder_->SwapFrontendAttributes(attributes);
1609 }
1610
1611 ~XlaScopedFrontendAttributesAssignment() {
1612 builder_->SetFrontendAttributes(saved_);
1613 }
1614
1615 private:
1616 xla::XlaBuilder* const builder_;
1617 FrontendAttributes saved_;
1618
1619 XlaScopedFrontendAttributesAssignment(
1620 const XlaScopedFrontendAttributesAssignment&) = delete;
1621 XlaScopedFrontendAttributesAssignment& operator=(
1622 const XlaScopedFrontendAttributesAssignment&) = delete;
1623 };
1624
1625 // RAII-style object: sets the current op metadata in builder on construction,
1626 // and sets back to the previous assignment on destruction.
1627 class XlaScopedOpMetadataAssignment {
1628 public:
1629 XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata)
1630 : builder_(builder) {
1631 saved_ = builder_->SwapOpMetadata(metadata);
1632 }
1633
1634 ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); }
1635
1636 private:
1637 xla::XlaBuilder* const builder_;
1638 OpMetadata saved_;
1639
1640 XlaScopedOpMetadataAssignment(const XlaScopedOpMetadataAssignment&) = delete;
1641 XlaScopedOpMetadataAssignment& operator=(
1642 const XlaScopedOpMetadataAssignment&) = delete;
1643 };
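
// Example usage (an illustrative sketch): ops built while the scope is alive
// carry the given metadata; the previous metadata is swapped back in on
// destruction.
//
//   OpMetadata m;
//   m.set_op_type("Add");
//   m.set_op_name("my_model/layer0/add");
//   {
//     XlaScopedOpMetadataAssignment meta(&b, m);
//     XlaOp y = Add(x0, x1);  // Carries the metadata above.
//   }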
1644
1645 // Free functions for building XlaOps. The intention is that these will
1646 // become the public API for building XlaOps rather than calling methods on
1647 // XlaBuilder directly.
1648 //
1649
1650 // Enqueues a "retrieve parameter value" instruction for a parameter that was
1651 // passed to the computation.
1652 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
1653 const Shape& shape, const std::string& name);
1654
1655 // Same as above, but with leaf buffer replication annotation.
1656 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
1657 const Shape& shape, const std::string& name,
1658 const std::vector<bool>& replicated_at_leaf_buffers);
1659
1660 // Enqueues a constant with the value of the given literal onto the
1661 // computation.
1662 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
1663
1664 // Enqueues a constant onto the computation. Methods are templated on the
1665 // native host type (NativeT) which corresponds to a specific XLA
1666 // PrimitiveType as given in the following table:
1667 //
1668 // Native Type PrimitiveType
1669 // -----------------------------
1670 // bool PRED
1671 // int32_t S32
1672 // int64_t S64
1673 // uint32_t U32
1674 // uint64_t U64
1675 // float F32
1676 // double F64
1677 //
1678 // Note: not all primitive types defined in xla_data.proto have a
1679 // corresponding native type yet.
1680 template <typename NativeT>
1681 XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
1682 template <typename NativeT>
1683 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
1684 XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
1685 template <typename NativeT>
1686 XlaOp ConstantR2(XlaBuilder* builder,
1687 std::initializer_list<std::initializer_list<NativeT>> values);
1688 template <typename NativeT>
1689 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
1690 const Array<NativeT>& values,
1691 const Layout& layout);
1692 template <typename NativeT>
1693 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
1694 template <typename NativeT>
1695 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
1696 const Array2D<NativeT>& values,
1697 const Layout& layout);
1698 template <typename NativeT>
1699 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
1700 const Array2D<NativeT>& values);
1701 template <typename NativeT>
1702 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
1703 const Array3D<NativeT>& values,
1704 const Layout& layout);
1705 template <typename NativeT>
1706 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
1707 const Array3D<NativeT>& values);
1708 template <typename NativeT>
1709 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
1710 const Array4D<NativeT>& values,
1711 const Layout& layout);
1712 template <typename NativeT>
1713 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
1714 const Array4D<NativeT>& values);
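
// Example (an illustrative sketch): the NativeT template argument selects the
// XLA element type according to the table above.
//
//   XlaBuilder b("constants");
//   XlaOp s = ConstantR0<int32_t>(&b, 42);                 // S32 scalar.
//   XlaOp v = ConstantR1<float>(&b, {1.0f, 2.0f});         // f32[2].
//   XlaOp m = ConstantR2<int32_t>(&b, {{1, 2}, {3, 4}});   // s32[2,2].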
1715
1716 // Enqueues a rank-1 constant vector onto the computation. The vector has
1717 // size 'length' and every element has the value
1718 // 'value'.
1719 template <typename NativeT>
1720 XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value);
1721
1722 // Adds dimensions to an array by duplicating the data in the array.
1723 //
1724 // The new dimensions are inserted on the left, i.e. if
1725 // broadcast_sizes has values {a0, ..., aN} and the operand shape
1726 // has dimensions {b0, ..., bM} then the shape of the output has
1727 // dimensions {a0, ..., aN, b0, ..., bM}.
1728 //
1729 // The new dimensions index into copies of the operand, i.e.
1730 //
1731 // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
1732 XlaOp Broadcast(XlaOp operand, absl::Span<const int64_t> broadcast_sizes);
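
// Example (an illustrative sketch): broadcasting f32[2] with
// broadcast_sizes={3} yields f32[3,2], where each of the 3 rows is a copy of
// the operand.
//
//   XlaOp row = ConstantR1<float>(&b, {1.0f, 2.0f});
//   XlaOp mat = Broadcast(row, /*broadcast_sizes=*/{3});  // f32[3,2].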
1733
1734 // This op broadcasts the `operand` to an output with the given `shape`.
1735 // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the
1736 // i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
1737 // dimension of the output. This also requires that the i'th input dimension is
1738 // either 1 or is the same as the output dimension it's broadcasting into.
1739 //
1740 // For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
1741 // output shape is s32[2,2]:
1742 // - Specifying {1} as broadcast_dimension will generate output
1743 // {{1, 2},
1744 // {1, 2}}
1745 // - On the other hand, specifying {0} as broadcast_dimension
1746 // will generate output
1747 // {{1 , 1},
1748 // {2 , 2}}
1749 XlaOp BroadcastInDim(XlaOp operand,
1750 const absl::Span<const int64_t> out_dim_size,
1751 const absl::Span<const int64_t> broadcast_dimensions);
1752
1753 // Copies the input operand to the output. This operation is for internal
1754 // purpose and is only used by the compiler for optimization purposes or to
1755 // ensure correctness. The XLA client should never have to generate this
1756 // instruction.
1757 //
1758 // Copy has two potential use cases:
1759 //
1760 // * Create a copy of the operand with a new layout.
1761 //
1762 // * Create a copy of the operand in a separately allocated buffer. This is
1763 // necessary for some backends if the operand is a parameter or constant and
1764 // the operand is returned within a tuple. In this case, the lifetime of the
1765 // operand buffer must be the same as the lifetime of the output result.
1766 // However, the lifetimes of parameters and constants are managed separately
1767 // from the lifetime of the output result. Creating a separate copy of the
1768 // parameter or constant buffer resolves this issue.
1769 XlaOp Copy(XlaOp operand);
1770
1771 // Enqueues a pad operation onto the computation that pads the given value on
1772 // the edges as well as between the elements of the input. padding_config
1773 // specifies the padding amount for each dimension.
1774 XlaOp Pad(XlaOp operand, XlaOp padding_value,
1775 const PaddingConfig& padding_config);
1776
1777 // Enqueues a pad operation in a given dimension, taking all other
1778 // dimensions as they are.
1779 XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
1780 int64_t pad_lo, int64_t pad_hi);
1781
1782 // Enqueues an operation onto the computation that flattens the operand based
1783 // on the dimension order (major/slowest-varying to minor/fastest-varying)
1784 // given, followed by reshaping it into the shape with the given dimension
1785 // sizes (also major to minor). Conceptually, this is a limited form of
1786 // "shape casting".
1787 XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
1788 absl::Span<const int64_t> new_sizes);
1789
1790 // Enqueues a dynamic reshape operation. The dynamic reshape takes additional
1791 // XlaOps as sizes for the result dimensions. Result dimension i is a dynamic
1792 // dimension if dims_are_dynamic[i] is true.
1793 XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
1794 absl::Span<const int64_t> new_size_bounds,
1795 const std::vector<bool>& dims_are_dynamic);
1796
1797 // Enqueues an operation onto the computation that collapses the operand,
1798 // from first to last dimension (C order), then reshapes it to the given
1799 // dimension sizes. Conceptually, this is a limited form of "shape casting".
1800 XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes);
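
// Example (an illustrative sketch): reshaping f32[2,3] to f32[3,2] preserves
// the row-major (C order) element sequence.
//
//   XlaOp m = ConstantR2<float>(&b, {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});
//   XlaOp r = Reshape(m, /*new_sizes=*/{3, 2});  // {{1,2},{3,4},{5,6}}.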
1801
1802 // Enqueues a Reshape op that uses an explicit target shape.
1803 XlaOp Reshape(const Shape& shape, XlaOp operand);
1804
1805 // `inferred_dimension` represents the output dimension that's inferred by
1806 // upper-level framework by dividing the input element count by the known
1807 // output element count. While an inferred_dimension can be static, if there
1808 // is a dynamic dimension in the output, it must be the inferred dimension.
1809 XlaOp ReshapeWithInferredDimension(XlaOp operand,
1810 absl::Span<const int64_t> new_sizes,
1811 int64_t inferred_dimension);
1812
1813 // Wrapper for Reshape.
1814 // Enqueues an operation to collapse the provided dimensions; e.g. an
1815 // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
1816 // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
1817 // be a consecutive, in-order subsequence of the operand dimensions.
1818 //
1819 // Note that collapsing a single dimension does nothing:
1820 //
1821 // {256} collapsing {0} => {256}
1822 // {1} collapsing {0} => {1}
1823 //
1824 // Collapsing multiple dimensions produces a single result dimension:
1825 //
1826 // {256, 2} collapsing {0,1} => {512}
1827 // {256, 2, 3} collapsing {0,1} => {512, 3}
1828 //
1829 // This could potentially cause data to be moved -- it provides a more
1830 // structured form of reshaping than an arbitrary Reshape operation.
1831 XlaOp Collapse(XlaOp operand, absl::Span<const int64_t> dimensions);
1832
1833 // Enqueues a slice operation onto the computation that slices the operand
1834 // from the start indices to the limit indices; e.g.
1835 //
1836 // x
1837 // [ 0 1 2 3 ]
1838 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
1839 // [ 8 9 a b ]
1840 //
1841 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
1842 // range notation.
1843 // The strides parameter determines the stride over the slice
1844 XlaOp Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
1845 absl::Span<const int64_t> limit_indices,
1846 absl::Span<const int64_t> strides);
1847
1848 // Enqueues a slice operation in a given dimension, taking all other
1849 // dimensions as they are; e.g., if dimno is 1, start_index is 2, limit_index
1850 // is 4, stride is 1, and the shape is f32[7,8,9], this call is short-hand
1851 // for:
1852 //
1853 // array[:, 2:4:1, :]
1854 XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index,
1855 int64_t stride, int64_t dimno);
1856
1857 // Enqueues a slice operation onto the computation that slices the 'operand'
1858 // from dynamic start indices which are passed in 'start_indices'.
1859 // The size of the slice in each dimension is passed in 'slice_sizes',
1860 // which specify the end point of exclusive slice intervals in each
1861 // dimension [start, start + size).
1862 // The shape of each element of 'start_indices' must be scalar, with the span
1863 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1864 // have the same shape.
1865 // Slice index calculations are computed modulo input dimension sizes to
1866 // prevent dynamic start indices from generating out-of-bound array accesses.
1867 XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
1868 absl::Span<const int64_t> slice_sizes);
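
// Example (an illustrative sketch): extracting a 2x2 window from an f32[3,3]
// operand `m` at offsets that are themselves XlaOps computed at runtime.
//
//   XlaOp i0 = ConstantR0<int32_t>(&b, 1);
//   XlaOp i1 = ConstantR0<int32_t>(&b, 0);
//   XlaOp w = DynamicSlice(m, {i0, i1}, /*slice_sizes=*/{2, 2});  // f32[2,2].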
1869
1870 // Enqueues a dynamic update slice operation onto the computation, which
1871 // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
1872 // The shape of 'update' determines the shape of the slice of 'operand'
1873 // which is updated.
1874 // The indices specified in 'start_indices' specify the offset of the slice
1875 // of 'operand' which is updated.
1876 //
1877 // update = {10, 11} // calculated at runtime.
1878 // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
1879 // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
1880 // [7 8 9] [7 8 9 ]
1881 //
1882 // The shape of each element of 'start_indices' must be scalar, with the span
1883 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1884 // have the same shape.
1885 // Slice index calculations are computed modulo update dimension sizes to
1886 // prevent dynamic start indices from generating out-of-bound array accesses.
1887 XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
1888 absl::Span<const XlaOp> start_indices);
1889
1890 // Enqueues a concatenate instruction onto the computation. 'operands' must
1891 // have >= 1 entry.
1892 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1893 int64_t dimension);
1894
1895 // Enqueues a conditional-move-like select operation onto the computation;
1896 // predicated on pred, selects between on_true and on_false.
1897 XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
1898
1899 // Enqueues a tuple-creation instruction onto the computation.
1900 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
1901
1902 // Enqueues a tuple-element-get instruction onto the computation.
1903 XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
1904
1905 // Enqueues an equal-to comparison instruction onto the computation.
1906 XlaOp Eq(XlaOp lhs, XlaOp rhs,
1907 absl::Span<const int64_t> broadcast_dimensions = {});
1908 XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
1909 absl::Span<const int64_t> broadcast_dimensions = {});
1910
1911 // Enqueues a not-equal comparison instruction onto the computation.
1912 XlaOp Ne(XlaOp lhs, XlaOp rhs,
1913 absl::Span<const int64_t> broadcast_dimensions = {});
1914 XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
1915 absl::Span<const int64_t> broadcast_dimensions = {});
1916
1917 // Enqueues a greater-or-equal comparison instruction onto the computation.
1918 XlaOp Ge(XlaOp lhs, XlaOp rhs,
1919 absl::Span<const int64_t> broadcast_dimensions = {});
1920 XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
1921 absl::Span<const int64_t> broadcast_dimensions = {});
1922
1923 // Enqueues a greater-than comparison instruction onto the computation.
1924 XlaOp Gt(XlaOp lhs, XlaOp rhs,
1925 absl::Span<const int64_t> broadcast_dimensions = {});
1926 XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
1927 absl::Span<const int64_t> broadcast_dimensions = {});
1928
1929 // Enqueues a less-than comparison instruction onto the computation.
1930 XlaOp Lt(XlaOp lhs, XlaOp rhs,
1931 absl::Span<const int64_t> broadcast_dimensions = {});
1932 XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
1933 absl::Span<const int64_t> broadcast_dimensions = {});
1934
1935 // Enqueues a less-or-equal comparison instruction onto the computation.
1936 XlaOp Le(XlaOp lhs, XlaOp rhs,
1937 absl::Span<const int64_t> broadcast_dimensions = {});
1938 XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
1939 absl::Span<const int64_t> broadcast_dimensions = {});
1940
1941 // Enqueues a comparison instruction onto the computation (optionally without
1942 // broadcast_dimensions, for consistency with the other overloads).
1943 XlaOp Compare(XlaOp lhs, XlaOp rhs,
1944 absl::Span<const int64_t> broadcast_dimensions,
1945 ComparisonDirection direction, Comparison::Type compare_type);
1946 XlaOp Compare(XlaOp lhs, XlaOp rhs,
1947 absl::Span<const int64_t> broadcast_dimensions,
1948 ComparisonDirection direction);
1949 XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
1950
1951 // Enqueues a dot instruction onto the computation.
1952 XlaOp Dot(XlaOp lhs, XlaOp rhs,
1953 const PrecisionConfig* precision_config = nullptr,
1954 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
1955
1956 // Enqueues a general dot instruction onto the computation.
1957 XlaOp DotGeneral(
1958 XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
1959 const PrecisionConfig* precision_config = nullptr,
1960 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
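
// Example (an illustrative sketch): a batched matmul
// f32[B,M,K] x f32[B,K,N] -> f32[B,M,N], where DotDimensionNumbers is a proto
// and the calls below are its standard repeated-field accessors.
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);  // K in lhs.
//   dnums.add_rhs_contracting_dimensions(1);  // K in rhs.
//   XlaOp out = DotGeneral(lhs, rhs, dnums);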
1961
1962 // Enqueues a convolution instruction onto the computation, which uses the
1963 // default convolution dimension numbers.
1964 XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1965 Padding padding, int64_t feature_group_count = 1,
1966 int64_t batch_group_count = 1,
1967 const PrecisionConfig* precision_config = nullptr,
1968 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
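
// Example (an illustrative sketch): a stride-1, SAME-padded 2D convolution
// using the default dimension numbers.
//
//   XlaOp y = Conv(input, kernel, /*window_strides=*/{1, 1}, Padding::kSame);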
1969
1970 // Enqueues a convolution instruction onto the computation, with the caller
1971 // provided padding configuration in the format returned by MakePadding().
1972 XlaOp ConvWithGeneralPadding(
1973 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1974 absl::Span<const std::pair<int64_t, int64_t>> padding,
1975 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1976 const PrecisionConfig* precision_config = nullptr,
1977 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
1978
1979 // Enqueues a convolution instruction onto the computation, with the caller
1980 // provided dimension numbers configuration.
1981 XlaOp ConvWithGeneralDimensions(
1982 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1983 Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1984 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1985 const PrecisionConfig* precision_config = nullptr,
1986 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
1987
1988 // Enqueues a convolution instruction onto the computation, with the caller
1989 // provided padding configuration as well as the dimension numbers.
1990 XlaOp ConvGeneral(
1991 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1992 absl::Span<const std::pair<int64_t, int64_t>> padding,
1993 const ConvolutionDimensionNumbers& dimension_numbers,
1994 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1995 const PrecisionConfig* precision_config = nullptr,
1996 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
1997
1998 // Enqueues a convolution instruction onto the computation, with the caller
1999 // provided padding configuration, dilation factors and dimension numbers.
2000 XlaOp ConvGeneralDilated(
2001 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
2002 absl::Span<const std::pair<int64_t, int64_t>> padding,
2003 absl::Span<const int64_t> lhs_dilation,
2004 absl::Span<const int64_t> rhs_dilation,
2005 const ConvolutionDimensionNumbers& dimension_numbers,
2006 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
2007 const PrecisionConfig* precision_config = nullptr,
2008 std::optional<PrimitiveType> preferred_element_type = std::nullopt,
2009 std::optional<std::vector<bool>> window_reversal = std::nullopt);
2010
2011 XlaOp DynamicConvForward(
2012 XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
2013 absl::Span<const std::pair<int64_t, int64_t>> padding,
2014 absl::Span<const int64_t> lhs_dilation,
2015 absl::Span<const int64_t> rhs_dilation,
2016 const ConvolutionDimensionNumbers& dimension_numbers,
2017 int64_t feature_group_count, int64_t batch_group_count,
2018 const PrecisionConfig* precision_config, PaddingType padding_type,
2019 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
2020
2021 XlaOp DynamicConvInputGrad(
2022 XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
2023 absl::Span<const int64_t> window_strides,
2024 absl::Span<const std::pair<int64_t, int64_t>> padding,
2025 absl::Span<const int64_t> lhs_dilation,
2026 absl::Span<const int64_t> rhs_dilation,
2027 const ConvolutionDimensionNumbers& dimension_numbers,
2028 int64_t feature_group_count, int64_t batch_group_count,
2029 const PrecisionConfig* precision_config, PaddingType padding_type,
2030 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
2031
2032 XlaOp DynamicConvKernelGrad(
2033 XlaOp activations, XlaOp gradients,
2034 absl::Span<const int64_t> window_strides,
2035 absl::Span<const std::pair<int64_t, int64_t>> padding,
2036 absl::Span<const int64_t> lhs_dilation,
2037 absl::Span<const int64_t> rhs_dilation,
2038 const ConvolutionDimensionNumbers& dimension_numbers,
2039 int64_t feature_group_count, int64_t batch_group_count,
2040 const PrecisionConfig* precision_config, PaddingType padding_type,
2041 std::optional<PrimitiveType> preferred_element_type = std::nullopt);
2042
2043 // Enqueues an FFT instruction onto the computation, of the given type and
2044 // with the given FFT length.
2045 XlaOp Fft(XlaOp operand, FftType fft_type,
2046 absl::Span<const int64_t> fft_length);
2047
2048 // Solves systems of linear equations with lower or upper triangular coefficient
2049 // matrices by forward- or back-substitution. Broadcasting along leading
2050 // dimensions, this routine solves for x in one of the matrix systems
2051 // `op(a) * x = b`, or `x * op(a) = b`,
2052 // for the variable `x` given `a` and `b`, where `op(a)` is either
2053 // `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
2054 //
2055 // * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
2056 // square matrices. If `lower` is true (false), then the strictly upper
2057 // (lower) triangular part of each innermost matrix in `a` is assumed to be
2058 // zero and is not accessed.
2059 // * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
2060 // tensor of shape `[..., K, M]`.
2061 // * `left_side` is a boolean, indicating whether to solve a system of the form
2062 // op(a) * x = b (true) or x * op(a) = b (false).
2063 // * `lower` is a boolean, indicating whether the argument `a` is
2064 // lower-triangular (true) or upper-triangular (false).
2065 // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
2066 // 1 and not accessed.
2067 // * `transpose_a` indicates which function `op` we use to transform the tensor
2068 // `a`: the identity function, transpose(a), or conjugate(transpose(a))
2069 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
2070 bool unit_diagonal,
2071 TriangularSolveOptions::Transpose transpose_a);
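
// Example (an illustrative sketch): solving op(a) * x = b for a
// lower-triangular `a` of shape f32[..., M, M] and `b` of shape
// f32[..., M, K], with op(a) = a.
//
//   XlaOp x = TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);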
2072
2073 // Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
2074 // positive definite matrices.
2075 // `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
2076 // two minor dimensions equal.
2077 // If `lower` is true, the data from the lower triangle is used; if false, the
2078 // upper triangle is used. The input data in the other triangle of the input
2079 // does not affect the output. Returns the output in the same lower/upper
2080 // triangle. The data returned in the other output triangle is arbitrary and
2081 // implementation-defined.
2082 //
2083 // If `a` is not Hermitian positive definite, returns an array full of NaNs.
2084 XlaOp Cholesky(XlaOp a, bool lower);
2085
2086 // Enqueues an infeed instruction onto the computation, which writes data of
2087 // the given shape to the infeed buffer of the device.
2088 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
2089 const std::string& config = "");
2090
2091 // Variant of Infeed which takes a token-shaped operand and produces a
2092 // two-element tuple containing the data value and a token-shaped value.
2093 // Tokens are used for ordering side-effecting operations.
2094 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
2095 XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
2096 const std::string& config = "");
2097
2098 // Enqueues an outfeed instruction onto the computation. This instruction
2099 // generates outgoing data transfers for the given data.
2100 //
2101 // shape_with_layout communicates the laid out shape that we want to outfeed
2102 // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
2103 // will occur.
2104 void Outfeed(XlaOp operand, const Shape& shape_with_layout,
2105 const std::string& outfeed_config);
2106
2107 // Variant of Outfeed which takes a token-shaped operand and produces a
2108 // token-shaped value. Tokens are used for ordering side-effecting operations.
2109 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
2110 XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
2111 const Shape& shape_with_layout,
2112 const std::string& outfeed_config);
2113
2114 // Enqueues a call instruction onto the computation.
2115 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
2116 absl::Span<const XlaOp> operands);
2117
2118 // Enqueues a custom call instruction onto the computation. A custom call
2119 // invokes code external to XLA. The |operands| are passed to the external code,
2120 // and the external code is expected to produce a result of the given
2121 // |shape|. The exact mechanism is backend-specific. For example, in the CPU
2122 // backend, a call instruction is emitted which targets a symbol with the name
2123 // |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
2124 // but |call_target_name| should be short as it may be used in labels. |opaque|
2125 // can encode arbitrarily large amounts of information. |has_side_effect|
2126 // specifies whether the instruction can have side effects.
2127 // |output_operand_aliasing| specifies a list of output/operand buffer pairs
2128 // that alias each other, where the output buffer is represented as a
2129 // ShapeIndex, and the operand buffer is represented as the operand index and
2130 // the ShapeIndex.
2131 XlaOp CustomCall(
2132 XlaBuilder* builder, const std::string& call_target_name,
2133 absl::Span<const XlaOp> operands, const Shape& shape,
2134 const std::string& opaque = "", bool has_side_effect = false,
2135 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2136 output_operand_aliasing = {},
2137 const Literal* literal = nullptr,
2138 CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2139 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
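
// Example (an illustrative sketch; "my_custom_op" and the result shape are
// assumptions, not a real registered call target):
//
//   XlaOp r = CustomCall(&b, "my_custom_op", {arg0, arg1},
//                        ShapeUtil::MakeShape(F32, {128}));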
2140
2141 // Overload which constructs a custom call that applies an Xla computation.
2142 XlaOp CustomCallWithComputation(
2143 XlaBuilder* builder, const std::string& call_target_name,
2144 absl::Span<const XlaOp> operands, const XlaComputation& computation,
2145 const Shape& shape, const std::string& opaque = "",
2146 bool has_side_effect = false,
2147 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2148 output_operand_aliasing = {},
2149 const Literal* literal = nullptr,
2150 CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2151 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2152
2153 // Overload which constructs a custom call with fixed layouts. The operands will
2154 // have the layouts specified by |operand_shapes_with_layout| when provided to
2155 // external code, and the external code is expected to produce a result with the
2156 // layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
2157 // and |operand_shapes_with_layout| must have layouts.
2158 XlaOp CustomCallWithLayout(
2159 XlaBuilder* builder, const std::string& call_target_name,
2160 absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
2161 absl::Span<const Shape> operand_shapes_with_layout,
2162 const std::string& opaque = "", bool has_side_effect = false,
2163 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2164 output_operand_aliasing = {},
2165 const Literal* literal = nullptr,
2166 CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2167 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2168
2169 // Overload which annotates a custom call with the given Window and
2170 // ConvolutionDimensionNumbers. Useful for custom-calls which represent
2171 // convolutions.
2172 //
2173 // This sets the layout of its operands if operand_shapes_with_layout is
2174 // nonempty, and it sets the layout of its result if `shape` has a layout.
2175 XlaOp CustomCallWithConvDnums(
2176 XlaBuilder* builder, const std::string& call_target_name,
2177 absl::Span<const XlaOp> operands, const Shape& shape,
2178 absl::Span<const Shape> operand_shapes_with_layout,
2179 const std::string& opaque, bool has_side_effect,
2180 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
2181 output_operand_aliasing,
2182 const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
2183 CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2184 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2185
2186 // Enqueues an optimization barrier onto the computation.
2187 XlaOp OptimizationBarrier(XlaOp operand);
2188
2189 // The following methods enqueue element-wise binary arithmetic operations
2190 // onto the computation. The shapes of the operands have to match unless one
2191 // of the operands is a scalar, or an explicit broadcast dimension is given
2192 // (see g3doc for more details).
2193
2194 // Enqueues a complex compose instruction onto the computation.
2195 XlaOp Complex(XlaOp real, XlaOp imag,
2196 absl::Span<const int64_t> broadcast_dimensions = {});
2197
2198 // Enqueues a complex conjugate instruction onto the computation.
2199 XlaOp Conj(XlaOp operand);
2200
2201 // Enqueues an add instruction onto the computation.
2202 XlaOp Add(XlaOp lhs, XlaOp rhs,
2203 absl::Span<const int64_t> broadcast_dimensions = {});
2204
2205 // Enqueues a subtract instruction onto the computation.
2206 XlaOp Sub(XlaOp lhs, XlaOp rhs,
2207 absl::Span<const int64_t> broadcast_dimensions = {});
2208
2209 // Enqueues a multiply instruction onto the computation.
2210 XlaOp Mul(XlaOp lhs, XlaOp rhs,
2211 absl::Span<const int64_t> broadcast_dimensions = {});
2212
2213 // Enqueues a divide instruction onto the computation.
2214 XlaOp Div(XlaOp lhs, XlaOp rhs,
2215 absl::Span<const int64_t> broadcast_dimensions = {});
2216
2217 // Enqueues a remainder instruction onto the computation.
2218 XlaOp Rem(XlaOp lhs, XlaOp rhs,
2219 absl::Span<const int64_t> broadcast_dimensions = {});
2220
2221 // Enqueues a max instruction onto the computation.
2222 XlaOp Max(XlaOp lhs, XlaOp rhs,
2223 absl::Span<const int64_t> broadcast_dimensions = {});
2224
2225 // Enqueues a min instruction onto the computation.
2226 XlaOp Min(XlaOp lhs, XlaOp rhs,
2227 absl::Span<const int64_t> broadcast_dimensions = {});
2228
2229 // Element-wise logical operators
2230 XlaOp And(XlaOp lhs, XlaOp rhs,
2231 absl::Span<const int64_t> broadcast_dimensions = {});
2232
2233 // Overload to call And with 3 or more operands. We need the following somewhat
2234 // convoluted overload set to disambiguate with the overload that takes the
2235 // `broadcast_dimensions` optional param.
2236 inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
2237 return And(op1, And(op2, op3));
2238 }
2239 template <typename... XlaOpTs>
2240 XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
2241 return And(op1, And(op2, And(op3, operands...)));
2242 }
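
// Example (an illustrative sketch): the variadic overload folds
// right-to-left, so
//
//   XlaOp all = And(p0, p1, p2, p3);  // == And(p0, And(p1, And(p2, p3))).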
2243
2244 XlaOp Or(XlaOp lhs, XlaOp rhs,
2245 absl::Span<const int64_t> broadcast_dimensions = {});
2246
2247 // Overload to call Or with 3 or more operands. As with `And`, we need the
2248 // following complicated overload set to handle the default arg in the `Or`
2249 // overload above.
2250 inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
2251 return Or(op1, Or(op2, op3));
2252 }
2253 template <typename... XlaOpTs>
2254 XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
2255 return Or(op1, Or(op2, Or(op3, operands...)));
2256 }
2257
2258 XlaOp Xor(XlaOp lhs, XlaOp rhs,
2259 absl::Span<const int64_t> broadcast_dimensions = {});
2260
2261 XlaOp Not(XlaOp operand);
2262
2263 XlaOp PopulationCount(XlaOp operand);
2264
2265 XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
2266 absl::Span<const int64_t> broadcast_dimensions = {});
2267 XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
2268 absl::Span<const int64_t> broadcast_dimensions = {});
2269 XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
2270 absl::Span<const int64_t> broadcast_dimensions = {});
2271
2272 // Reduces an array along the provided dimensions, given "computation" as a
2273 // reduction operator.
2274 XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
2275 absl::Span<const int64_t> dimensions_to_reduce);
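
// Example (an illustrative sketch; error handling on Build() is elided):
// summing an f32[4,8] `operand` over dimension 0 yields f32[8]. The scalar
// add computation is built with a separate sub-builder.
//
//   XlaBuilder sub("add");
//   XlaOp x = Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "x");
//   XlaOp y = Parameter(&sub, 1, ShapeUtil::MakeShape(F32, {}), "y");
//   Add(x, y);
//   XlaComputation add = sub.Build().value();
//   XlaOp sum = Reduce(operand, ConstantR0<float>(&b, 0.0f), add,
//                      /*dimensions_to_reduce=*/{0});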
2276
2277 // Reduces several arrays simultaneously among the provided dimensions, given
2278 // "computation" as a reduction operator.
2279 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
2280 absl::Span<const XlaOp> init_values,
2281 const XlaComputation& computation,
2282 absl::Span<const int64_t> dimensions_to_reduce);
2283
2284 // Convenience wrapper around the above that reduces all the dimensions in the
2285 // operand shape.
2286 XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
2287 const XlaComputation& computation);
2288
2289 // Enqueues a windowed reduce instruction onto the computation.
2290 XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
2291 const XlaComputation& computation,
2292 absl::Span<const int64_t> window_dimensions,
2293 absl::Span<const int64_t> window_strides, Padding padding);
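
// Example (an illustrative sketch): 2x2 max pooling with stride 2 over an
// f32[N,H,W,C] image, given a scalar `max` computation built with a
// sub-builder as in the Reduce example above.
//
//   XlaOp init =
//       ConstantR0<float>(&b, -std::numeric_limits<float>::infinity());
//   XlaOp pooled = ReduceWindow(img, init, max,
//                               /*window_dimensions=*/{1, 2, 2, 1},
//                               /*window_strides=*/{1, 2, 2, 1},
//                               Padding::kValid);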
2294
2295 XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
2296 absl::Span<const XlaOp> init_values,
2297 const XlaComputation& computation,
2298 absl::Span<const int64_t> window_dimensions,
2299 absl::Span<const int64_t> window_strides, Padding padding);
2300
2301 // As ReduceWindow(), but the padding is given in the format
2302 // returned by MakePadding().
2303 XlaOp ReduceWindowWithGeneralPadding(
2304 XlaOp operand, XlaOp init_value, const XlaComputation& computation,
2305 absl::Span<const int64_t> window_dimensions,
2306 absl::Span<const int64_t> window_strides,
2307 absl::Span<const int64_t> base_dilations,
2308 absl::Span<const int64_t> window_dilations,
2309 absl::Span<const std::pair<int64_t, int64_t>> padding);
2310 XlaOp ReduceWindowWithGeneralPadding(
2311 absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
2312 const XlaComputation& computation,
2313 absl::Span<const int64_t> window_dimensions,
2314 absl::Span<const int64_t> window_strides,
2315 absl::Span<const int64_t> base_dilations,
2316 absl::Span<const int64_t> window_dilations,
2317 absl::Span<const std::pair<int64_t, int64_t>> padding);
2318
2319 // Returns the sum of the operand value within each subgroup of replicas. All
2320 // replicas supply one input to the sum and all replicas receive the resulting
2321 // sum for each subgroup.
2322 XlaOp CrossReplicaSum(XlaOp operand,
2323 absl::Span<const ReplicaGroup> replica_groups = {});
2324
2325 XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
2326 int64_t shard_count,
2327 absl::Span<const ReplicaGroup> replica_groups = {},
2328 const std::optional<ChannelHandle>& channel_id = std::nullopt,
2329 const std::optional<Layout>& layout = std::nullopt,
2330 const std::optional<bool> use_global_device_ids = std::nullopt);
2331
2332 // Enqueues an operation that does an AllReduce of the operand across cores.
2333 // Here AllReduce means doing a reduction on the input operand across cores and
2334 // then broadcasting the reduction result to those cores. The reduction function is
2335 // defined by `computation`, which should be a commutative computation on
2336 // scalars, e.g., add, min, or max. The way that AllReduce is applied is
2337 // configured by:
2338 //
2339 // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
2340 // empty, all replicas belong to one group. AllReduce will be applied within
2341 // subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
2342 // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
2343 //
2344 // - `channel_id`: for AllReduce nodes from different modules, if they have the
2345 // same channel_id, they will be 'AllReduce'd together. If empty, AllReduce will
2346 // not be applied across modules.
2347 //
2348 // - `shape_with_layout`: forces the layout of the AllReduce to the given
2349 // layout. This is used to guarantee the same layout for a group of AllReduce
2350 // ops compiled separately.
2351 XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
2352 absl::Span<const ReplicaGroup> replica_groups = {},
2353 const std::optional<ChannelHandle>& channel_id = std::nullopt,
2354 const std::optional<Shape>& shape_with_layout = std::nullopt,
2355 const std::optional<bool> use_global_device_ids = std::nullopt);
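
// Example (an illustrative sketch): a cross-replica sum over subgroups {0,2}
// and {1,3}, with `add` a scalar addition computation.
//
//   ReplicaGroup g0, g1;
//   g0.add_replica_ids(0);
//   g0.add_replica_ids(2);
//   g1.add_replica_ids(1);
//   g1.add_replica_ids(3);
//   XlaOp summed = AllReduce(x, add, {g0, g1});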
2356
2357 XlaOp ReduceScatter(
2358 XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
2359 int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups = {},
2360 const std::optional<ChannelHandle>& channel_id = std::nullopt,
2361 const std::optional<Layout>& layout = std::nullopt,
2362 const std::optional<bool> use_global_device_ids = std::nullopt);
2363
2364 // Enqueues an operation that does an AllToAll of the operand across cores.
2365 // An optional `layout` can be specified to force the layout of the instruction.
2366 // This is used to guarantee the same layout for a group of AllToAll ops
2367 // compiled separately.
2368 XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension,
2369 int64_t split_count,
2370 absl::Span<const ReplicaGroup> replica_groups = {},
2371 const std::optional<Layout>& layout = std::nullopt);
2372
2373 XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
2374 absl::Span<const ReplicaGroup> replica_groups = {},
2375 const std::optional<Layout>& layout = std::nullopt);
2376
2377 XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
2378 int64_t concat_dimension, int64_t split_count,
2379 absl::Span<const ReplicaGroup> replica_groups = {},
2380 const std::optional<Layout>& layout = std::nullopt);
2381
2382 // Enqueues a collective operation that sends and receives data across replicas.
2383 //
2384 // - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
2385 // pairs. For each pair, the operand is sent from the source replica to the
2386 // target replica. Note that 1) no two pairs may have the same target replica
2387 // id, and no two pairs may have the same source replica id; 2) if a replica id
2388 // is not a target in any pair, then the output on that replica is a tensor
2389 // consisting of 0(s) with the same shape as the input.
2390 XlaOp CollectivePermute(
2391 XlaOp operand,
2392 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
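
// Example (an illustrative sketch): rotating data by one position among 4
// replicas; each replica sends to its right neighbor and receives from its
// left neighbor.
//
//   XlaOp rotated = CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});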
2393
2394 // Enqueues an operation that returns the replica ID.
2395 XlaOp ReplicaId(XlaBuilder* builder);
2396
2397 // Enqueues an operation that scatters the `source` array to the selected
2398 // indices of each window.
2399 XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
2400 absl::Span<const int64_t> window_dimensions,
2401 absl::Span<const int64_t> window_strides,
2402 Padding padding, XlaOp source, XlaOp init_value,
2403 const XlaComputation& scatter);
2404
2405 // As SelectAndScatter(), but the padding is given in the format
2406 // returned by MakePadding().
2407 XlaOp SelectAndScatterWithGeneralPadding(
2408 XlaOp operand, const XlaComputation& select,
2409 absl::Span<const int64_t> window_dimensions,
2410 absl::Span<const int64_t> window_strides,
2411 absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
2412 XlaOp init_value, const XlaComputation& scatter);
2413
2414 // Enqueues an abs instruction onto the computation.
2415 XlaOp Abs(XlaOp operand);
2416
2417 // Enqueues an atan2 instruction onto the computation.
2418 XlaOp Atan2(XlaOp y, XlaOp x,
2419 absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(XlaOp operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(XlaOp operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(XlaOp operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(XlaOp operand);

// Enqueues a round instruction onto the computation,
// with half-way cases rounding away from zero.
XlaOp Round(XlaOp operand);

// Enqueues a round instruction onto the computation, rounding half-way cases
// to the nearest even.
XlaOp RoundNearestEven(XlaOp operand);
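//
// Worked example of the two rounding modes (a sketch; `b` is an assumed
// XlaBuilder):
//
//   XlaOp v = ConstantR1<float>(&b, {0.5f, 1.5f, 2.5f, -0.5f});
//   Round(v);             // yields {1, 2, 3, -1} (ties away from zero)
//   RoundNearestEven(v);  // yields {0, 2, 2, -0} (ties to even)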

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(XlaOp operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(XlaOp operand);

// Enqueues a logistic instruction onto the computation.
XlaOp Logistic(XlaOp operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(XlaOp operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(XlaOp operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(XlaOp operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(XlaOp operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(XlaOp operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(XlaOp operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(XlaOp operand);

// Enqueues a sqrt instruction onto the computation.
XlaOp Sqrt(XlaOp operand);

// Enqueues a cbrt instruction onto the computation.
XlaOp Cbrt(XlaOp operand);

// Enqueues an rsqrt instruction onto the computation.
XlaOp Rsqrt(XlaOp operand);

// Enqueues an lhs^rhs instruction onto the computation.
XlaOp Pow(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not +/-Inf or NaN. Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(XlaOp operand);
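//
// Illustrative sketch: for x = {1.0, inf, nan, -0.0}, IsFinite(x) yields
// {true, false, false, true}.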

// Enqueues an iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);
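//
// Illustrative sketch: the first overload counts along `iota_dimension`, so
// for an s32[2, 3] shape with iota_dimension=1 every row is {0, 1, 2}; the
// second overload is shorthand for a rank-1 result:
//
//   XlaOp rows = Iota(&b, ShapeUtil::MakeShape(S32, {2, 3}),
//                     /*iota_dimension=*/1);
//   XlaOp vec = Iota(&b, S32, /*size=*/10);  // {0, 1, ..., 9}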

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
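//
// Illustrative sketch of the difference (assumes an f32 operand `x`):
//
//   XlaOp ints = ConvertElementType(x, S32);  // value conversion: 2.75f -> 2
//   XlaOp bits = BitcastConvertType(x, U32);  // bit reinterpretation:
//                                             // 1.0f -> 0x3F800000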

// Enqueues a negate instruction onto the computation.
XlaOp Neg(XlaOp operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(XlaOp operand, absl::Span<const int64_t> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(XlaOp operand, absl::Span<const int64_t> dimensions);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided, then the
//   last dimension is chosen by default. For the dimension which is sorted, the
//   same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types of
//   the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
//   two index positions.
// Default comparator computations can be found in lib/comparators.h.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension = -1, bool is_stable = false);
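//
// Illustrative sketch: sort keys and carry values along with them, using an
// ascending comparator assumed to come from lib/comparators.h:
//
//   XlaComputation less = CreateScalarLtComputation({F32, S32}, &b);
//   XlaOp sorted = Sort({keys, values}, less, /*dimension=*/-1,
//                       /*is_stable=*/true);
//   XlaOp sorted_keys = GetTupleElement(sorted, 0);
//   XlaOp sorted_values = GetTupleElement(sorted, 1);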

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
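//
// Illustrative sketch: elementwise clamp of `x` into [0, 6], as in a ReLU6.
// Note the argument order: min, operand, max.
//
//   XlaOp y = Clamp(ConstantR0<float>(&b, 0.0f), x,
//                   ConstantR0<float>(&b, 6.0f));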

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation,
          absl::Span<const int64_t> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
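//
// Illustrative sketch: 128 samples drawn uniformly from [0, 1).
//
//   XlaOp u = RngUniform(ConstantR0<float>(&b, 0.0f),
//                        ConstantR0<float>(&b, 1.0f),
//                        ShapeUtil::MakeShape(F32, {128}));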

// Enqueues a B(initial_state) random bit generation instruction onto the
// computation. Returns the new key and random bits with the specified shape.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                      const Shape& shape);

// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            XlaOp init);
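//
// Illustrative sketch of a counting loop: `cond` and `body` are
// XlaComputations over an s32 scalar, built with sub-builders (a sketch under
// those assumptions, not a prescribed pattern):
//
//   auto cb = b.CreateSubBuilder("cond");
//   Lt(Parameter(cb.get(), 0, ShapeUtil::MakeShape(S32, {}), "i"),
//      ConstantR0<int32_t>(cb.get(), 10));
//   XlaComputation cond = cb->Build().value();
//
//   auto bb = b.CreateSubBuilder("body");
//   Add(Parameter(bb.get(), 0, ShapeUtil::MakeShape(S32, {}), "i"),
//       ConstantR0<int32_t>(bb.get(), 1));
//   XlaComputation body = bb->Build().value();
//
//   XlaOp result = While(cond, body, ConstantR0<int32_t>(&b, 0));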

// Enqueues a conditional node onto the computation.
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                  const XlaComputation& true_computation, XlaOp false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index uses the N-1'th
// branch_computation as the default.
XlaOp Conditional(XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);
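//
// Illustrative sketch of the indexed form (comp0..comp2 and arg0..arg2 are
// assumed to exist); idx values outside [0, 2] fall through to comp2:
//
//   XlaOp result = Conditional(idx, {&comp0, &comp1, &comp2},
//                              {arg0, arg1, arg2});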

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(XlaOp input, XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64_t> slice_sizes,
             bool indices_are_sorted = false);
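//
// Illustrative sketch: gather whole rows of an f32[n, d] matrix by an s32[k]
// vector of row indices, yielding f32[k, d] (a hypothetical setup, shown to
// make the dimension_numbers concrete):
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);            // output dim 1 comes from the slice
//   dnums.add_collapsed_slice_dims(0);   // drop the size-1 row dimension
//   dnums.add_start_index_map(0);        // indices address operand dim 0
//   dnums.set_index_vector_dim(1);       // treat indices as k scalar starts
//   XlaOp rows = Gather(matrix, indices, dnums, /*slice_sizes=*/{1, d});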

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);
XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
              absl::Span<const XlaOp> updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
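//
// Illustrative sketch of token threading: order two sends after a fresh token
// and then join their result tokens (`h1` and `h2` are assumed
// ChannelHandles):
//
//   XlaOp t0 = CreateToken(&b);
//   XlaOp t1 = SendWithToken(data1, t0, h1);
//   XlaOp t2 = SendWithToken(data2, t0, h2);
//   XlaOp joined = AfterAll(&b, {t1, t2});  // both sends precede `joined`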

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64_t feature_index);
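//
// Illustrative sketch: normalize an NHWC activation over feature dim 3 and
// unpack the result tuple:
//
//   XlaOp bn = BatchNormTraining(act, scale, offset, /*epsilon=*/1e-5f,
//                                /*feature_index=*/3);
//   XlaOp normalized = GetTupleElement(bn, 0);
//   XlaOp batch_mean = GetTupleElement(bn, 1);
//   XlaOp batch_var = GetTupleElement(bn, 2);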

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64_t feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
// - grad_operand: Gradient with respect to input `operand`
// - grad_offset: Gradient with respect to input `offset`
// - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64_t feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

// Returns the same op but with the dynamic dimension removed.
XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);
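//
// Illustrative sketch of the dynamic-dimension helpers (assumes `x` is
// f32[16] and `n` is an s32 scalar with n <= 16):
//
//   XlaOp size = GetDimensionSize(x, 0);           // s32 scalar, 16 here
//   XlaOp dyn = SetDimensionSize(x, n, 0);         // f32[<=16], dim 0 dynamic
//   XlaOp fixed = RemoveDynamicDimension(dyn, 0);  // back to static f32[16]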

// Implementation details below this point.
//

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  BorrowingLiteral literal(
      reinterpret_cast<const char*>(values.begin()),
      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
                           {static_cast<int64_t>(values.size())}));
  return ConstantLiteral(builder, literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}
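
// Illustrative sketch of the constant helpers above (`b` is an assumed
// XlaBuilder):
//
//   XlaOp scalar = ConstantR0<float>(&b, 42.0f);            // f32[] 42
//   XlaOp vec = ConstantR1<float>(&b, {1.0f, 2.0f, 3.0f});  // f32[3]
//   XlaOp ones = ConstantR1<int32_t>(&b, /*length=*/8, 1);  // s32[8] of 1s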

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

// Switches from automatic SPMD partitioning to manual partitioning. Converts a
// full-shaped tensor (to be automatically partitioned by the SPMD partitioner)
// to a shard-shaped tensor to be consumed by manually partitioned ops.
StatusOr<xla::XlaOp> ConvertSpmdFullToShardShape(
    xla::XlaBuilder* builder, xla::XlaOp input, int single_dim,
    const xla::OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims);

// Switches from manual partitioning to automatic SPMD partitioning. Converts a
// shard-shaped tensor (manually partitioned in SPMD-style) to a full-shaped
// tensor to be partitioned automatically by the SPMD partitioner.
StatusOr<xla::XlaOp> ConvertSpmdShardToFullShape(
    xla::XlaBuilder* builder, xla::XlaOp input, const xla::Shape& output_shape,
    int single_dim, const xla::OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_