/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;
class XlaOp;
class HloInstruction;

namespace internal {

struct XlaBuilderFriend {
  static XlaOp BuildFusion(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           absl::string_view fusion_kind,
                           const XlaComputation& fused_computation);

  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
                            const Shape& shape);

  static HloInstructionProto* GetInstruction(XlaOp op);
};

}  // namespace internal

// This represents an instruction that has been enqueued using the XlaBuilder.
// This is used to pass to subsequent computations that depend upon the
// instruction as an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar(). Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point possibly
  // deep in the callstack when we finally dereference `this`. The precondition
  // lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(XlaOp rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64 handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};
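
// Example (sketch; uses the free functions Parameter and ConstantR0 declared
// later in this header):
//
//   XlaBuilder b("add_one");
//   XlaOp x = Parameter(&b, /*parameter_number=*/0,
//                       ShapeUtil::MakeShape(F32, {}), "x");
//   XlaOp y = x + ConstantR0<float>(&b, 1.0f);
//
// Each XlaOp is a cheap (handle, builder) pair naming an enqueued
// instruction; it can be freely copied and passed to later ops.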

// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(XlaOp x);
XlaOp operator+(XlaOp x, XlaOp y);
XlaOp operator-(XlaOp x, XlaOp y);
XlaOp operator*(XlaOp x, XlaOp y);
XlaOp operator/(XlaOp x, XlaOp y);
XlaOp operator%(XlaOp x, XlaOp y);

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(XlaOp x);
XlaOp operator&(XlaOp x, XlaOp y);
XlaOp operator|(XlaOp x, XlaOp y);
XlaOp operator^(XlaOp x, XlaOp y);
XlaOp operator<<(XlaOp x, XlaOp y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise
// performs a right logical shift.
XlaOp operator>>(XlaOp x, XlaOp y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because
// the semantics might be surprising since their result types are usually
// 'bool'. Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.
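
// Example (sketch; assumes `x` and `y` are S32-shaped XlaOps from the same
// builder):
//
//   XlaOp z = (x + y) * x - y;  // enqueues Add, Mul and Sub instructions
//   XlaOp s = x >> y;           // arithmetic shift, since S32 is signed
//
// Each operator enqueues a new instruction on the operands' builder and
// returns a new XlaOp; the operands themselves are never mutated.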

// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  virtual ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }

  // Swaps the passed op metadata with the metadata currently set.
  //
  // Returns the old op metadata.
  OpMetadata SwapOpMetadata(OpMetadata metadata) {
    OpMetadata old_metadata = std::move(metadata_);
    metadata_ = std::move(metadata);
    return old_metadata;
  }
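
  // Example (sketch; op_type/op_name are fields of the OpMetadata proto):
  //
  //   OpMetadata metadata;
  //   metadata.set_op_type("Add");
  //   metadata.set_op_name("my_model/add_1");
  //   builder.SetOpMetadata(metadata);
  //   // ... ops enqueued here carry this metadata ...
  //   builder.ClearOpMetadata();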

  // Similar to SetOpMetadata, but only sets the metadata for the next op.
  void SetOneShotOpMetadata(OpMetadata metadata) {
    one_shot_metadata_ = std::move(metadata);
  }

  // Clears the OpMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }

  // Sets an OpSharding that will be attached to all instructions until
  // cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Sets the FrontendAttributes that will be added to all instructions until
  // cleared.
  //
  // FrontendAttributes are often applied to a series of XLA HLO instructions.
  // As a result they are set on the computation builder and all the
  // instructions generated via the computation builder will have the same
  // frontend attributes attached to them.
  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
    frontend_attributes_ = frontend_attributes;
  }

  // Swaps the passed FrontendAttributes with the ones currently set.
  //
  // Returns the old attributes.
  FrontendAttributes SwapFrontendAttributes(
      const FrontendAttributes& frontend_attributes) {
    FrontendAttributes old_attributes = std::move(frontend_attributes_);
    frontend_attributes_ = frontend_attributes;
    return old_attributes;
  }

  // Returns the FrontendAttributes that will be attached to all instructions.
  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Clears all the frontend attributes.
  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }

  // Clears the sharding. Ops will be sharded according to the default
  // placement policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build()
  // is called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64 kConvBatchDimension = 0;
  static constexpr int64 kConvFeatureDimension = 1;
  static constexpr int64 kConvFirstSpatialDimension = 2;
  static constexpr int64 kConvSecondSpatialDimension = 3;
  static constexpr int64 kConvKernelOutputDimension = 0;
  static constexpr int64 kConvKernelInputDimension = 1;
  static constexpr int64 kConvKernelFirstSpatialDimension = 2;
  static constexpr int64 kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);
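
  // Example (sketch): with an NCHW input and OIHW kernel, the defaults above
  // apply directly, so a 2D convolution can be written as:
  //
  //   auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
  //   XlaOp out = ConvWithGeneralDimensions(
  //       input, kernel, /*window_strides=*/{1, 1}, Padding::kSame, dnums);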

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
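
  // Example (sketch; builds a scalar max reducer on a sub-builder for use
  // with Reduce):
  //
  //   XlaBuilder b("reduce_max");
  //   auto sub = b.CreateSubBuilder("max");
  //   auto x = Parameter(sub.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  //   auto y = Parameter(sub.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
  //   Max(x, y);
  //   XlaComputation max_comp = sub->BuildAndNoteError();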

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = false);
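
  // Example (sketch; error handling elided):
  //
  //   XlaBuilder b("square");
  //   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "x");
  //   XlaOp y = x * x;  // y is the last added op, so it becomes the root.
  //   StatusOr<XlaComputation> computation = b.Build();
  //
  // Passing an explicit root, e.g. b.Build(x), overrides the default
  // last-added-op root.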

  // Builds the computation with the requested operations, or notes an error
  // in the parent XlaBuilder and returns an empty computation if building
  // failed. This function is intended to be used where the returned
  // XlaComputation is only used by the parent XlaBuilder and hence further
  // operations on the returned XlaComputation will simply be errored out if
  // an error occurred while building this computation. If the built
  // computation is to be used by an XlaBuilder other than the parent
  // XlaBuilder then Build() should be used instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that roots on the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
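
  // Example (sketch): a root computed purely from constants qualifies,
  //
  //   XlaOp size = ConstantR0<int32>(&b, 2) * ConstantR0<int32>(&b, 4);
  //   auto subgraph = b.BuildConstantSubGraph(size);  // ok, folds to 8
  //
  // whereas rooting the subgraph at a parameter-derived op returns an error.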

  // Similar to BuildConstantSubGraph, but with the root element type changed
  // to boolean. A true value in the root indicates that the value is dynamic
  // while a false value indicates that the value is a constant. This will
  // copy the needed ops/computations to the subgraph.
  //
  // E.g.,
  // Computation {
  //   a = 3
  //   b = param(0)
  //   ROOT Tuple(a + b, a + 1, b + 1)
  // }
  // Calling BuildDynamicInferenceGraph on root will produce the following
  // graph:
  //
  // Computation {
  //   a = False
  //   b = True
  //   ROOT Tuple(a | b, a, b)
  // }
  //
  // The result, which is (True, False, True) after evaluation, can be
  // interpreted as "First element is dynamic; Second element is static; Third
  // element is dynamic".
  StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(XlaOp op) const;

  // Returns the shape of the given op.
  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using
  // the given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //    Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value
  // can be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns
  // an invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
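
  // Example (sketch): op-building helpers typically wrap their bodies so
  // that failures become deferred builder errors:
  //
  //   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  //     TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(operand));
  //     // ... validate `shape`, returning a Status on failure ...
  //     return operand;
  //   });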

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(XlaOp operand) const;
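
  // Example (sketch):
  //
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   XlaOp c = ConstantR0<float>(&b, 1.0f) + ConstantR0<float>(&b, 2.0f);
  //   b.IsConstant(c);      // true: depends only on constants
  //   b.IsConstant(c + p);  // false: depends on a parameter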

  // Sets up a binding which indicates that the `target_dim_num` in the
  // subshape `target_param_index` of parameter `target_param_num` is a
  // dynamic dimension and its real dynamic size is represented by
  // `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.")
  Status SetDynamicBinding(int64 dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64 target_param_num,
                           ShapeIndex target_param_index,
                           int64 target_dim_num);

  // Adds a new input/output alias. Since the input/output shape information
  // is not available until the computation is built, any eventual error in
  // the arguments of this API will be detected only at computation Build()
  // time.
  //
  // Note: Except when 'must-alias' is true, the alias is assumed to be
  // 'may-alias', and only a buffer donated at runtime will be aliased with
  // the output. If a buffer is not donated at runtime, a copy will be
  // inserted by XLA to prevent buffer clobbering.
  void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                  const ShapeIndex& param_index,
                  HloInputOutputAliasConfig::AliasKind kind =
                      HloInputOutputAliasConfig::AliasKind::kMayAlias) {
    input_output_aliases_.push_back(
        {output_index, param_number, param_index, kind});
  }
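
  // Example (sketch): alias the first element of a tuple-shaped result with
  // the whole of parameter 0, so a donated input buffer can be updated in
  // place:
  //
  //   b.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
  //                /*param_index=*/{});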

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64 param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
    // Specifies if the alias is a must-alias or may-alias.
    HloInputOutputAliasConfig::AliasKind kind;
  };

  // Looks up the HloInstruction and sets the frontend attribute "attribute"
  // to "value".
  //
  // If the attribute already existed, then its value is updated.
  //
  // Note: the attribute is only added to the HloInstruction, not to the
  // builder.
  Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
                                         string value);

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

 private:
  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);

  // Descriptions for the methods below can be found in the corresponding
  // public functions section in this file.

  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name,
                  const std::vector<bool>& replicated_at_leaf_buffers);
  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name) {
    std::vector<bool> empty_bools;
    return Parameter(parameter_number, shape, name, empty_bools);
  }

  virtual XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);

  XlaOp BroadcastInDim(XlaOp operand,
                       const absl::Span<const int64> out_dim_size,
                       const absl::Span<const int64> broadcast_dimensions);

  XlaOp Pad(XlaOp operand, XlaOp padding_value,
            const PaddingConfig& padding_config);
  XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno, int64 pad_lo,
                 int64 pad_hi);

  virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                                      XlaOp padding_value,
                                      const PaddingConfig& padding_config);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes,
                int64 inferred_dimension = -1);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                int64 inferred_dimension = -1);

  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64 inferred_dimension = -1);

  XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                       absl::Span<const int64> new_size_bounds,
                       const std::vector<bool>& dims_are_dynamic);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);
  virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> start_indices,
                                        absl::Span<const int64> limit_indices,
                                        absl::Span<const int64> strides);
  virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                           int64 stride, int64 dimno);

  XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64> slice_sizes);
  virtual StatusOr<XlaOp> DynamicSliceInternal(
      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
      absl::Span<const int64> slice_sizes);

  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                           absl::Span<const XlaOp> start_indices);
  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
      const Shape& shape, XlaOp operand, XlaOp update,
      absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
  virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
                                              absl::Span<const XlaOp> operands,
                                              int64 dimension);

  void Trace(const string& tag, XlaOp operand);

  XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);
  virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
                                        absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
  virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
                                                  XlaOp tuple_data,
                                                  int64 index);

  XlaOp Dot(
      XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DotGeneral(
      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp Conv(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, int64 feature_group_count = 1,
      int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  StatusOr<HloInstructionProto> DynamicConvInstruction(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);

  XlaOp Fft(XlaOp operand, FftType fft_type,
            absl::Span<const int64> fft_length);
  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                                      FftType fft_type,
                                      absl::Span<const int64> fft_length);

  virtual StatusOr<XlaOp> TriangularSolveInternal(
      const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options);

  virtual StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
                                           bool lower);

  XlaOp Infeed(const Shape& shape, const string& config = "");
  XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
  virtual StatusOr<XlaOp> InfeedWithTokenInternal(
      const Shape& infeed_instruction_shape, XlaOp token,
      const string& config);

  void Outfeed(XlaOp operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);
  virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
      const string& outfeed_config);
  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);

  // Internal version of CustomCall, without a computation, that doesn't do
  // op-specific error handling and expects arguments to be legal. The
  // CustomCall method above calls this method after error handling.
  virtual StatusOr<XlaOp> CustomCallInternal(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation& computation, const Shape& shape_with_layout,
      const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);

  XlaOp Reduce(XlaOp operand, XlaOp init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  virtual StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64> dimensions_to_reduce);

  XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                     absl::Span<const XlaOp> init_values,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  StatusOr<HloInstructionProto> ReduceWindowInternal(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  virtual StatusOr<XlaOp> ReduceWindowInternal(
      const Shape& shape, XlaOp operand, XlaOp init_value,
      const XlaComputation& computation, Window window);
  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp AllGather(
      XlaOp operand, int64 all_gather_dimension, int64 shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Layout>& layout = absl::nullopt,
      const absl::optional<bool> use_global_device_ids = absl::nullopt);

  XlaOp AllReduce(
      XlaOp operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Shape>& shape_with_layout = absl::nullopt);

  XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                 int64 split_count,
                 const std::vector<ReplicaGroup>& replica_groups,
                 const absl::optional<Layout>& layout = absl::nullopt);

  XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                      int64 concat_dimension, int64 split_count,
                      const std::vector<ReplicaGroup>& replica_groups,
                      const absl::optional<Layout>& layout);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, XlaOp source, XlaOp init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  StatusOr<HloInstructionProto> SelectAndScatterInternal(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  virtual XlaOp Iota(const Shape& shape, int64 iota_dimension);

  XlaOp Iota(PrimitiveType type, int64 size);

  XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
                                                     XlaOp operand);

  XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
  virtual StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);

  XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                                      absl::Span<const int64> dimensions);

  XlaOp Sort(absl::Span<const XlaOp> operands,
             const XlaComputation& comparator, int64 dimension = -1,
             bool is_stable = false);
  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
                                       absl::Span<const XlaOp> operands,
                                       const XlaComputation& comparator,
                                       int64 dimension, bool is_stable);

  XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

  XlaOp Map(absl::Span<const XlaOp> operands,
            const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

  XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

  XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                        const Shape& shape);
  // Internal variant for the op with the full result shape containing both
  // data and state shape as a tuple.
  virtual StatusOr<XlaOp> RngBitGeneratorInternal(
      const Shape& full_result_shape, RandomAlgorithm algorithm,
      XlaOp initial_state);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              XlaOp init);
  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                        const XlaComputation& condition,
                                        const XlaComputation& body,
                                        XlaOp init);

  XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                    const XlaComputation& true_computation,
                    XlaOp false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(
      XlaOp branch_index,
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                        const int mantissa_bits);
  virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
                                                  XlaOp operand,
                                                  const int exponent_bits,
                                                  const int mantissa_bits);

  XlaOp Gather(XlaOp input, XlaOp start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes,
               bool indices_are_sorted = false);

  virtual StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);

  virtual StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
      const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers,
      bool indices_are_sorted, bool unique_indices);

  void Send(XlaOp operand, const ChannelHandle& handle);
  XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

  XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                   const ChannelHandle& handle);

  XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                     const ChannelHandle& handle);

  virtual XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                          float epsilon, int64 feature_index);

  XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                           XlaOp mean, XlaOp variance, float epsilon,
                           int64 feature_index);

  XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                      XlaOp batch_var, XlaOp grad_output, float epsilon,
                      int64 feature_index);

  XlaOp GetDimensionSize(XlaOp operand, int64 dimension);

  XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

  virtual StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape,
                                                   XlaOp operand, XlaOp val,
                                                   int64 dimension);

  XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);

  virtual StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                         HloOpcode opcode,
                                         absl::Span<const XlaOp> operands);
  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                 HloOpcode opcode) {
    return AddInstruction(std::move(instr), opcode, /*operands=*/{});
  }

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64 handle) const;
  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(
      int64 handle);

  // Internal helper method that does the building for an arbitrary unary op.
  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction
  // is only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                 absl::Span<const int64> broadcast_dimensions,
                 absl::optional<ComparisonDirection> direction = absl::nullopt,
                 absl::optional<Comparison::Type> type = absl::nullopt);
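
  // Example (sketch): adding a {3}-shaped vector to a {2, 3}-shaped matrix
  // broadcasts the vector along output dimension 1:
  //
  //   BinaryOp(HloOpcode::kAdd, matrix, vector, /*broadcast_dimensions=*/{1});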

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction);

  // Internal helper method for binary op compare without broadcast
  // dimensions.
  virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                  ComparisonDirection direction,
                                  Comparison::Type type);

  // Internal helper method that does the building for an arbitrary binary op
  // with same-ranked operands that doesn't broadcast.
  virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                    XlaOp lhs, XlaOp rhs);

  // Internal helper method that does the building for an arbitrary ternary
  // op.
  XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  virtual StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                        absl::Span<const XlaOp> parameters,
                                        const Shape& shape);

  virtual StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       XlaOp operand);

  // Internal helper method for creating a Reshape op with the already
  // inferred shape.
  virtual StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                          int64 inferred_dimension);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false
  // iff a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64 op_handle,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  int64 GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored
  // within the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
  std::vector<std::unique_ptr<Shape>> instruction_shapes_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias()
  // API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where
  // the instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this
  // metadata throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // A temporary metadata that will only be applied to the next op created.
  absl::optional<OpMetadata> one_shot_metadata_;

  // Sharding for this operator. This is structured as a "modal"-like
  // operation, in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  FrontendAttributes frontend_attributes_;

  friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                         const Shape& shape, const string& name,
                         const std::vector<bool>& replicated_at_leaf_buffers);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(XlaOp operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      XlaOp operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Copy(XlaOp operand);

  friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno,
                        int64 pad_lo, int64 pad_hi);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                              absl::Span<const int64> new_size_bounds,
                              const std::vector<bool>& dims_are_dynamic);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64> new_sizes,
                                            int64 inferred_dimension);

  friend XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  friend XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                          int64 stride, int64 dimno);

  friend XlaOp DynamicSlice(XlaOp operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64 dimension);

  friend void Trace(const string& tag, XlaOp operand);

  friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction,
                       Comparison::Type compare_type);
  friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
                   const PrecisionConfig* precision_config,
                   absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DotGeneral(
      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  virtual StatusOr<XlaOp> DotGeneralInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs,
      const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config);
  friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64 feature_group_count, int64 batch_group_count,
                    const PrecisionConfig* precision_config,
                    absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvKernelGrad(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp Fft(XlaOp operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape,
      const string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);
  friend XlaOp CustomCallWithComputation(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const XlaComputation& computation,
      const Shape& shape, const string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);
  friend XlaOp Complex(XlaOp real, XlaOp imag,
                       absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Conj(XlaOp operand);
  friend XlaOp Add(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Div(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Max(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Min(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp And(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Or(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Not(XlaOp operand);
  friend XlaOp PopulationCount(XlaOp operand);
  friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                         absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                            absl::Span<const XlaOp> init_values,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  friend XlaOp CrossReplicaSum(XlaOp operand,
                               absl::Span<const ReplicaGroup> replica_groups);
  friend XlaOp AllGather(XlaOp operand, int64 all_gather_dimension,
                         int64 shard_count,
                         absl::Span<const ReplicaGroup> replica_groups,
                         const absl::optional<ChannelHandle>& channel_id,
                         const absl::optional<Layout>& layout,
                         const absl::optional<bool> use_global_device_ids);
  friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                         absl::Span<const ReplicaGroup> replica_groups,
                         const absl::optional<ChannelHandle>& channel_id,
                         const absl::optional<Shape>& shape_with_layout);
  friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
                        int64 concat_dimension, int64 split_count,
                        const std::vector<ReplicaGroup>& replica_groups,
                        const absl::optional<Layout>& layout);
  friend XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                             int64 concat_dimension, int64 split_count,
                             const std::vector<ReplicaGroup>& replica_groups,
                             const absl::optional<Layout>& layout);
  friend XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
  friend XlaOp ReplicaId(XlaBuilder* builder);
  friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                absl::Span<const int64> window_dimensions,
                                absl::Span<const int64> window_strides,
                                Padding padding, XlaOp source,
                                XlaOp init_value,
                                const XlaComputation& scatter);
  friend XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);
  friend XlaOp Abs(XlaOp operand);
  friend XlaOp Atan2(XlaOp y, XlaOp x,
                     absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Exp(XlaOp operand);
  friend XlaOp Expm1(XlaOp operand);
  friend XlaOp Floor(XlaOp operand);
  friend XlaOp Ceil(XlaOp operand);
  friend XlaOp Round(XlaOp operand);
  friend XlaOp Log(XlaOp operand);
  friend XlaOp Log1p(XlaOp operand);
  friend XlaOp Logistic(XlaOp operand);
  friend XlaOp Sign(XlaOp operand);
  friend XlaOp Clz(XlaOp operand);
  friend XlaOp Cos(XlaOp operand);
  friend XlaOp Sin(XlaOp operand);
  friend XlaOp Tanh(XlaOp operand);
  friend XlaOp Real(XlaOp operand);
  friend XlaOp Imag(XlaOp operand);
  friend XlaOp Sqrt(XlaOp operand);
  friend XlaOp Rsqrt(XlaOp operand);
  friend XlaOp Cbrt(XlaOp operand);
  friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp IsFinite(XlaOp operand);
  friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
                    int64 iota_dimension);
  friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
  friend XlaOp ConvertElementType(XlaOp operand,
                                  PrimitiveType new_element_type);
  friend XlaOp BitcastConvertType(XlaOp operand,
                                  PrimitiveType new_element_type);
  friend XlaOp Neg(XlaOp operand);
1362 friend XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
1363 friend XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
1364 friend XlaOp Sort(absl::Span<const XlaOp> operands,
1365 const XlaComputation& comparator, int64 dimension,
1366 bool is_stable);
1367 friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
1368 friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1369 const XlaComputation& computation,
1370 absl::Span<const int64> dimensions,
1371 absl::Span<const XlaOp> static_operands);
1372 friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
1373 friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
1374 friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
1375 const Shape& shape);
1376 friend XlaOp While(const XlaComputation& condition,
1377 const XlaComputation& body, XlaOp init);
1378 friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
1379 const XlaComputation& true_computation,
1380 XlaOp false_operand,
1381 const XlaComputation& false_computation);
1382 friend XlaOp Conditional(
1383 XlaOp branch_index,
1384 absl::Span<const XlaComputation* const> branch_computations,
1385 absl::Span<const XlaOp> branch_operands);
1386 friend XlaOp ConditionalImpl(
1387 XlaOp branch_index,
1388 absl::Span<const XlaComputation* const> branch_computations,
1389 absl::Span<const XlaOp> branch_operands);
1390 friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
1391 const int mantissa_bits);
1392 friend XlaOp Gather(XlaOp input, XlaOp start_indices,
1393 const GatherDimensionNumbers& dimension_numbers,
1394 absl::Span<const int64> slice_sizes,
1395 bool indices_are_sorted);
1396 friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1397 const XlaComputation& update_computation,
1398 const ScatterDimensionNumbers& dimension_numbers,
1399 bool indices_are_sorted, bool unique_indices);
1400 friend void Send(XlaOp operand, const ChannelHandle& handle);
1401 friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1402 const ChannelHandle& handle);
1403 friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
1404 float epsilon, int64 feature_index);
1405 friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
1406 XlaOp mean, XlaOp variance, float epsilon,
1407 int64 feature_index);
1408 friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
1409 XlaOp batch_var, XlaOp grad_output, float epsilon,
1410 int64 feature_index);
1411 friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
1412 const ChannelHandle& handle);
1413 friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
1414 const ChannelHandle& handle);
1415 friend XlaOp SendToHost(XlaOp operand, XlaOp token,
1416 const Shape& shape_with_layout,
1417 const ChannelHandle& handle);
1418 friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
1419 const ChannelHandle& handle);
1420 friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1421 const string& config);
1422 friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1423 const Shape& shape_with_layout,
1424 const string& outfeed_config);
1425 friend XlaOp CreateToken(XlaBuilder* builder);
1426 friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1427
1428 friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
1429 friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
1430 friend XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
1431
1432 protected:
1433 // Returns OK status if the given op was built using this builder. Otherwise,
1434 // returns an error.
1435 Status CheckOpBuilder(XlaOp op) const;
1436
1437 private:
1438 XlaOp ConditionalImpl(
1439 XlaOp branch_index,
1440 absl::Span<const XlaComputation* const> branch_computations,
1441 absl::Span<const XlaOp> branch_operands);
1442
1443 XlaOp AllToAllArray(XlaOp operand, int64 split_dimension,
1444 int64 concat_dimension, int64 split_count,
1445 const std::vector<ReplicaGroup>& replica_groups);
1446
1447 // Creates an op with the given opcode and the output shape.
1448 virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
1449 absl::Span<const XlaOp> operands);
1450
1451 // Here, InstructionType is either const HloInstructionProto* or non-const
1452 // HloInstructionProto*.
1453 template <typename InstructionType>
LookUpInstructionByHandleInternal(int64 handle)1454 StatusOr<InstructionType> LookUpInstructionByHandleInternal(
1455 int64 handle) const {
1456 auto it = handle_to_index_.find(handle);
1457 if (it == handle_to_index_.end()) {
1458 return InvalidArgument("No XlaOp with handle %d", handle);
1459 }
1460 return const_cast<InstructionType>(&instructions_.at(it->second));
1461 }
1462
1463 // Here, InstructionType is either const HloInstructionProto* or non-const
1464 // HloInstructionProto*.
1465 //
1466 // TODO(hinsu): Return const pointer within StatusOr and use
1467 // absl::implicit_cast at callsites. This requires implicit_cast support in
1468 // stream_executor::port::StatusOr similar to absl::StatusOr.
1469 template <typename InstructionType>
LookUpInstructionInternal(XlaOp op)1470 StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
1471 TF_RETURN_IF_ERROR(CheckOpBuilder(op));
1472 return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
1473 }
1474
1475 friend struct internal::XlaBuilderFriend;
1476 };

// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
class XlaScopedShardingAssignment {
 public:
  XlaScopedShardingAssignment(xla::XlaBuilder* builder,
                              absl::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
  XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
      delete;

  ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  void SetSharding(const absl::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::XlaBuilder* const builder_;
  absl::optional<OpSharding> prev_sharding_;
};

// RAII-style object: saves the current builder's frontend attributes on
// construction and merges them with the new ones; restores the original
// attributes on destruction.
class XlaScopedFrontendAttributesAssignment {
 public:
  XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
                                        FrontendAttributes attributes)
      : builder_(builder) {
    saved_ = builder_->SwapFrontendAttributes(attributes);
  }

  ~XlaScopedFrontendAttributesAssignment() {
    builder_->SetFrontendAttributes(saved_);
  }

 private:
  xla::XlaBuilder* const builder_;
  FrontendAttributes saved_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
};

// RAII-style object: sets the current op metadata in builder on construction,
// and sets back to the previous assignment on destruction.
class XlaScopedOpMetadataAssignment {
 public:
  XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata)
      : builder_(builder) {
    saved_ = builder_->SwapOpMetadata(metadata);
  }

  ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); }

 private:
  xla::XlaBuilder* const builder_;
  OpMetadata saved_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment);
};

// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.
//

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name);

// Same as above, but with leaf buffer replication annotation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name,
                const std::vector<bool>& replicated_at_leaf_buffers);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);

// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//   Native Type    PrimitiveType
//  ------------------------------
//   bool           PRED
//   int32          S32
//   int64          S64
//   uint32         U32
//   uint64         U64
//   float          F32
//   double         F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
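//
// Example (illustrative sketch; assumes an XlaBuilder `b`):
//
//   XlaOp s = ConstantR0<float>(&b, 42.0f);      // f32[] scalar
//   XlaOp v = ConstantR1<int32>(&b, {1, 2, 3});  // s32[3]
//   XlaOp m = ConstantR2<float>(&b, {{1.0f, 2.0f},
//                                    {3.0f, 4.0f}});  // f32[2,2]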
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);

// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
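//
// Example (illustrative sketch; assumes an XlaBuilder `b`):
//
//   XlaOp x = ConstantR1<float>(&b, {1.0f, 2.0f, 3.0f});  // f32[3]
//   XlaOp y = Broadcast(x, /*broadcast_sizes=*/{2});      // f32[2,3]
//   // y is {{1, 2, 3}, {1, 2, 3}}.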
XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);

// This op broadcasts the `operand` to an output with the given `shape`.
// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the
// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
// dimension of the output. This also requires that the i'th input dimension is
// either 1 or is the same as the output dimension it's broadcasting into.
//
// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
// output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimensions will generate output
//   {{1, 2},
//    {1, 2}}
// - On the other hand, specifying {0} as broadcast_dimensions
//   will generate output
//   {{1, 1},
//    {2, 2}}
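//
// Example (illustrative sketch; assumes an XlaBuilder `b`):
//
//   XlaOp x = ConstantR1<int32>(&b, {1, 2});  // s32[2]
//   XlaOp rows = BroadcastInDim(x, /*out_dim_size=*/{2, 2},
//                               /*broadcast_dimensions=*/{1});  // {{1,2},{1,2}}
//   XlaOp cols = BroadcastInDim(x, /*out_dim_size=*/{2, 2},
//                               /*broadcast_dimensions=*/{0});  // {{1,1},{2,2}}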
XlaOp BroadcastInDim(XlaOp operand, const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions);

// Copies the input operand to the output. This operation is for internal
// purposes and is only used by the compiler for optimization purposes or to
// ensure correctness. The XLA client should never have to generate this
// instruction.
//
// Copy has two potential use cases:
//
// * Create a copy of the operand with a new layout.
//
// * Create a copy of the operand in a separately allocated buffer. This is
//   necessary for some backends if the operand is a parameter or constant and
//   the operand is returned within a tuple. In this case, the lifetime of the
//   operand buffer must be the same as the lifetime of the output result.
//   However, the lifetimes of parameters and constants are managed separately
//   from the lifetime of the output result. Creating a separate copy of the
//   parameter or constant buffer resolves this issue.
XlaOp Copy(XlaOp operand);

// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(XlaOp operand, XlaOp padding_value,
          const PaddingConfig& padding_config);

// Enqueues a pad operation in a given dimension, taking all other
// dimensions as they are.
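//
// Example (illustrative sketch; assumes an XlaBuilder `b`):
//
//   XlaOp x = ConstantR1<float>(&b, {1.0f, 2.0f});  // f32[2]
//   XlaOp zero = ConstantR0<float>(&b, 0.0f);
//   XlaOp y = PadInDim(x, zero, /*dimno=*/0,
//                      /*pad_lo=*/1, /*pad_hi=*/2);  // {0, 1, 2, 0, 0}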
XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno, int64 pad_lo,
               int64 pad_hi);

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues a dynamic reshape operation. The dynamic reshape takes additional
// XlaOps as sizes for the result dimensions. Result dim i is a dynamic
// dimension if dims_are_dynamic[i] is true.
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic);

// Enqueues an operation onto the computation that collapses the operand,
// from first to last dimension (C order), then reshapes it to the given
// dimension sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

// Enqueues a Reshape op that uses an explicit target shape.
XlaOp Reshape(const Shape& shape, XlaOp operand);

// `inferred_dimension` represents the output dimension that's inferred by
// upper-level framework by dividing the input element count by the known
// output element count. While an inferred_dimension can be static, if there
// is a dynamic dimension in the output, it must be the inferred dimension.
XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64> new_sizes,
                                   int64 inferred_dimension);

// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
//   {256} collapsing {0} => {256}
//   {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
//   {256, 2} collapsing {0,1} => {512}
//   {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
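//
// Example (illustrative sketch; assumes an XlaOp `x` of shape
// f32[256,2,2,32]):
//
//   XlaOp y = Collapse(x, /*dimensions=*/{0, 1, 2});  // f32[1024,32]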
XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
//        x
//   [ 0 1 2 3 ]
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
//   [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice.
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides);

// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1, start_index is 2, limit_index
// is 4, stride is 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
//   array[:, 2:4:1, :]
XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno);

// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',
// which specify the end point of exclusive slice intervals in each
// dimension [start, start + size).
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices'
// must have the same shape.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
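//
// Example (illustrative sketch; assumes an XlaBuilder `b` and an XlaOp `x` of
// shape f32[3,4]):
//
//   XlaOp i = ConstantR0<int32>(&b, 1);
//   XlaOp j = ConstantR0<int32>(&b, 2);
//   XlaOp s = DynamicSlice(x, /*start_indices=*/{i, j},
//                          /*slice_sizes=*/{2, 2});  // f32[2,2]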
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes);

// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
//               update = {10, 11} // calculated at runtime.
//   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
//   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
//   [7 8 9]                                                  [7 8  9 ]
//
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices'
// must have the same shape.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                         absl::Span<const XlaOp> start_indices);

// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
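//
// Example (illustrative sketch; assumes an XlaBuilder `b`):
//
//   XlaOp a = ConstantR1<float>(&b, {1.0f, 2.0f});      // f32[2]
//   XlaOp c = ConstantR1<float>(&b, {3.0f});            // f32[1]
//   XlaOp d = ConcatInDim(&b, {a, c}, /*dimension=*/0); // {1, 2, 3}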
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64 dimension);

// Enqueues a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
void Trace(const string& tag, XlaOp operand);

// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);

// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(XlaOp tuple_data, int64 index);

// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a comparison instruction onto the computation (optionally without
// broadcast_dimensions for consistency with others).
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type);
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction);
XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);

// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,
          const PrecisionConfig* precision_config = nullptr,
          absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a general dot instruction onto the computation.
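//
// Example (illustrative sketch): a batched matmul of `lhs` f32[B,M,K] with
// `rhs` f32[B,K,N], contracting over K and batching over B.
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   XlaOp out = DotGeneral(lhs, rhs, dnums);  // f32[B,M,N]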
XlaOp DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

XlaOp DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);

// Solves systems of linear equations with lower or upper triangular coefficient
// matrices by forward- or back-substitution. Broadcasting along leading
// dimensions, this routine solves for x in one of the matrix systems
//   `op(a) * x = b`, or `x * op(a) = b`,
// for the variable `x`, given `a` and `b`, where `op(a)` is either
//   `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
//
// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
//   square matrices. If `lower` is true (false), then the strictly upper
//   (lower) triangular part of each innermost matrix in `a` is assumed to be
//   zero and is not accessed.
// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
//   tensor of shape `[..., K, M]`.
// * `left_side` is a boolean, indicating whether to solve a system of the form
//   op(a) * x = b (true) or x * op(a) = b (false).
// * `lower` is a boolean, indicating whether the argument `a` is
//   lower-triangular (true) or upper-triangular (false).
// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
//   1 and not accessed.
// * `transpose_a` indicates which function `op` we use to transform the tensor
//   `a`: the identity function, transpose(a), or conjugate(transpose(a)).
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a);

// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
// positive definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
// two minor dimensions equal.
// If `lower` is true, the data from the lower triangle is used; if false, the
// upper triangle is used. The input data in the other triangle of the input
// does not affect the output. Returns the output in the same lower/upper
// triangle. The data returned in the other output triangle is arbitrary and
// implementation-defined.
//
// If `a` is not Hermitian positive definite, returns an array full of NaNs.
XlaOp Cholesky(XlaOp a, bool lower);

// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const string& config = "");

// Variant of Infeed which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                      const string& config = "");

// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(XlaOp operand, const Shape& shape_with_layout,
             const string& outfeed_config);

// Variant of Outfeed which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config);

// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands);

// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external code,
// and the external code is expected to produce a result of the given
// |shape|. The exact mechanism is backend-specific. For example, in the CPU
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
// but |call_target_name| should be short as it may be used in labels. |opaque|
// can encode arbitrarily large amounts of information. |has_side_effect|
// specifies whether the instruction can have side effects.
// |output_operand_aliasing| specifies a list of output/operand buffer pairs
// that alias each other, where the output buffer is represented as a
// ShapeIndex, and the operand buffer is represented as the operand index and
// the ShapeIndex.
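//
// Example (illustrative sketch; "my_custom_kernel" is a hypothetical target
// name, and `x` an existing XlaOp):
//
//   XlaOp out = CustomCall(&b, "my_custom_kernel", /*operands=*/{x},
//                          /*shape=*/ShapeUtil::MakeShape(F32, {16}));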
XlaOp CustomCall(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    const string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr);

// Overload which constructs a custom call that applies an Xla computation.
XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr);

// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
// external code, and the external code is expected to produce a result with the
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
// and |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
    absl::Span<const Shape> operand_shapes_with_layout,
    const string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr);

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).

// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(XlaOp real, XlaOp imag,
              absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(XlaOp operand);

// Enqueues an add instruction onto the computation.
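//
// Example (illustrative sketch; assumes an XlaBuilder `b`): adding a vector to
// each row of a matrix by broadcasting it along dimension 1 of the output.
//
//   XlaOp m = ConstantR2<float>(&b, {{1.0f, 2.0f},
//                                    {3.0f, 4.0f}});     // f32[2,2]
//   XlaOp v = ConstantR1<float>(&b, {10.0f, 20.0f});     // f32[2]
//   XlaOp sum = Add(m, v, /*broadcast_dimensions=*/{1});
//   // sum is {{11, 22}, {13, 24}}.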
XlaOp Add(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a subtract instruction onto the computation.
XlaOp Sub(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a multiply instruction onto the computation.
XlaOp Mul(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a divide instruction onto the computation.
XlaOp Div(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a remainder instruction onto the computation.
XlaOp Rem(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a max instruction onto the computation.
XlaOp Max(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a min instruction onto the computation.
XlaOp Min(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Element-wise logical operators
XlaOp And(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Overload to call And with 3 or more operands. We need the following somewhat
// convoluted overload set to disambiguate with the overload that takes the
// `broadcast_dimensions` optional param.
inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
  return And(op1, And(op2, op3));
}
template <typename... XlaOpTs>
XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return And(op1, And(op2, And(op3, operands...)));
}

XlaOp Or(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Overload to call Or with 3 or more operands. As with `And`, we need the
// following complicated overload set to handle the default arg in the `Or`
// overload above.
inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
  return Or(op1, Or(op2, op3));
}
template <typename... XlaOpTs>
XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return Or(op1, Or(op2, Or(op3, operands...)));
}

XlaOp Xor(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Not(XlaOp operand);

XlaOp PopulationCount(XlaOp operand);

XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
                        absl::Span<const int64> broadcast_dimensions = {});

// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
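//
// Example (illustrative sketch; assumes an XlaBuilder `b` and an operand `m`
// of shape f32[2,3]): summing over dimension 1 with a hand-built scalar add
// computation.
//
//   XlaBuilder sub("add");
//   Shape scalar = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(&sub, 0, scalar, "x"), Parameter(&sub, 1, scalar, "y"));
//   XlaComputation add = sub.Build().ConsumeValueOrDie();
//
//   XlaOp zero = ConstantR0<float>(&b, 0.0f);
//   XlaOp sums = Reduce(m, zero, add, /*dimensions_to_reduce=*/{1});  // f32[2]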
XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                const XlaComputation& computation);

// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    XlaOp operand, XlaOp init_value, const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});

XlaOp AllGather(
    XlaOp operand, int64 all_gather_dimension, int64 shard_count,
    absl::Span<const ReplicaGroup> replica_groups = {},
    const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
    const absl::optional<Layout>& layout = absl::nullopt,
    const absl::optional<bool> use_global_device_ids = absl::nullopt);

// Enqueues an operation that performs an AllReduce of the operand across
// cores. Here AllReduce means doing a reduction on the input operand across
// cores and then broadcasting the reduction result to those cores. The
// reduction function is defined by `computation`, which should be a
// commutative computation on scalars, e.g., add, min, or max. The way that
// AllReduce is applied is configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
// empty, all replicas belong to one group. AllReduce will be applied within
// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
// replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied across modules.
//
// - `shape_with_layout`: forces the layout of the AllReduce to the given
// layout. This is used to guarantee the same layout for a group of AllReduce
// ops compiled separately.
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups = {},
                const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
                const absl::optional<Shape>& shape_with_layout = absl::nullopt);

// Enqueues an operation that performs an AllToAll of the operand across cores.
// An optional `layout` can be specified to force the layout of the instruction.
// This is used to guarantee the same layout for a group of AllToAll ops
// compiled separately.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
               int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups = {},
               const absl::optional<Layout>& layout = absl::nullopt);

XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                    int64 concat_dimension, int64 split_count,
                    const std::vector<ReplicaGroup>& replica_groups = {},
                    const absl::optional<Layout>& layout = absl::nullopt);

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that: 1) no two pairs may have the same target replica
// id, and no two pairs may have the same source replica id; 2) if a replica id
// is not a target in any pair, then the output on that replica is a tensor of
// zeros with the same shape as the input.
XlaOp CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs);

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       XlaOp source, XlaOp init_value,
                       const XlaComputation& scatter);

// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(XlaOp operand);

// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(XlaOp y, XlaOp x,
            absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(XlaOp operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(XlaOp operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(XlaOp operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(XlaOp operand);

// Enqueues a round instruction onto the computation, rounding to nearest with
// half-way cases rounding away from zero.
XlaOp Round(XlaOp operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(XlaOp operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(XlaOp operand);

// Enqueues a logistic instruction onto the computation.
XlaOp Logistic(XlaOp operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(XlaOp operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(XlaOp operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(XlaOp operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(XlaOp operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(XlaOp operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(XlaOp operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(XlaOp operand);

// Enqueues a sqrt computation onto the computation.
XlaOp Sqrt(XlaOp operand);

// Enqueues a cbrt computation onto the computation.
XlaOp Cbrt(XlaOp operand);

// Enqueues an rsqrt computation onto the computation.
XlaOp Rsqrt(XlaOp operand);

// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e., not
// +/-Inf or NaN. Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(XlaOp operand);

// Enqueues an iota operation onto the computation.
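//
// Example (illustrative sketch; assumes an XlaBuilder `b`):
//
//   XlaOp i = Iota(&b, ShapeUtil::MakeShape(S32, {2, 3}),
//                  /*iota_dimension=*/0);  // {{0, 0, 0}, {1, 1, 1}}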
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a negate instruction onto the computation.
XlaOp Neg(XlaOp operand);

// Enqueues a transpose instruction onto the computation.
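//
// Example (illustrative sketch; assumes an XlaOp `x` of shape f32[2,3,4]):
//
//   XlaOp y = Transpose(x, /*permutation=*/{2, 0, 1});  // f32[4,2,3]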
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided, then the
//   last dimension is chosen by default. For the dimension which is sorted, the
//   same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types of
//   the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
//   two index positions.
// Default comparator computations can be found in lib/comparators.h
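//
// Example (illustrative sketch; assumes an XlaBuilder `b` and a rank-1 f32
// operand `x`): sorting in ascending order with a hand-built less-than
// comparator.
//
//   XlaBuilder cb("less_than");
//   Shape scalar = ShapeUtil::MakeShape(F32, {});
//   Lt(Parameter(&cb, 0, scalar, "a"), Parameter(&cb, 1, scalar, "b"));
//   XlaComputation less = cb.Build().ConsumeValueOrDie();
//
//   XlaOp sorted = Sort({x}, less, /*dimension=*/-1, /*is_stable=*/true);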
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension = -1, bool is_stable = false);

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

// Enqueues a B(initial_state) random bit generation instruction onto the
// computation. Returns the new key and random bits with the specified shape.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                      const Shape& shape);

// Enqueues a while node onto the computation.
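//
// Example (illustrative sketch; assumes an XlaBuilder `b`): a loop that counts
// from 0 to 10. Both subcomputations take the loop state (here a scalar s32)
// as their single parameter.
//
//   Shape state = ShapeUtil::MakeShape(S32, {});
//
//   XlaBuilder condb("cond");
//   Lt(Parameter(&condb, 0, state, "s"), ConstantR0<int32>(&condb, 10));
//   XlaComputation cond = condb.Build().ConsumeValueOrDie();
//
//   XlaBuilder bodyb("body");
//   Add(Parameter(&bodyb, 0, state, "s"), ConstantR0<int32>(&bodyb, 1));
//   XlaComputation body = bodyb.Build().ConsumeValueOrDie();
//
//   XlaOp result = While(cond, body, ConstantR0<int32>(&b, 0));  // yields 10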
2419 XlaOp While(const XlaComputation& condition, const XlaComputation& body,
2420 XlaOp init);
2421
2422 // Enqueues a conditional node onto the computation.
2423 XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
2424 const XlaComputation& true_computation, XlaOp false_operand,
2425 const XlaComputation& false_computation);
2426
2427 // Enqueues either a predicated (if/else) or indexed (switch/case/default)
2428 // conditional node onto the computation. N >= 1 branch_computations and
2429 // branch_operands are matched by index. branch_index selects the branch that
2430 // will be executed. Out of range branch_index uses the N-1'th
2431 // branch_computation as default.
2432 XlaOp Conditional(XlaOp branch_index,
2433 absl::Span<const XlaComputation* const> branch_computations,
2434 absl::Span<const XlaOp> branch_operands);
2435
2436 // Enqueues a ReducePrecision node onto the computation.
2437 XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
2438 const int mantissa_bits);
2439
2440 // Enqueues a Gather node onto the computation.
2441 XlaOp Gather(XlaOp input, XlaOp start_indices,
2442 const GatherDimensionNumbers& dimension_numbers,
2443 absl::Span<const int64> slice_sizes,
2444 bool indices_are_sorted = false);
2445
2446 // Enqueues a Scatter node onto the computation.
2447 XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
2448 const XlaComputation& update_computation,
2449 const ScatterDimensionNumbers& dimension_numbers,
2450 bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);
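
// Example pairing the token variants (illustrative sketch): two computations
// share a ChannelHandle `channel`; how the handle is obtained from the client
// is elided here, and all names are assumptions.
//
//   // In the sending computation, built with builder `sb`:
//   XlaOp send_done = SendWithToken(payload, CreateToken(&sb), channel);
//
//   // In the receiving computation, built with builder `rb`:
//   XlaOp recvd = RecvWithToken(CreateToken(&rb), payload_shape, channel);
//   XlaOp data = GetTupleElement(recvd, 0);
//   XlaOp recv_done = GetTupleElement(recvd, 1);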

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands
// must be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
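
// Example for AfterAll (sketch): joining the tokens of two independent
// side-effecting ops so that later ops are ordered after both; `t0` and `t1`
// are assumed token-shaped results of earlier instructions.
//
//   XlaOp joined = AfterAll(&b, {t0, t1});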

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64 feature_index);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_offset: Gradient with respect to input `offset`
//   - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64 feature_index);
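
// Example for BatchNormTraining (sketch): training-mode batch norm over NHWC
// activations, where the feature dimension is index 3; `activations`,
// `scale`, and `offset` are assumed operands. The per-batch statistics are
// unpacked from the result tuple.
//
//   XlaOp bn = BatchNormTraining(activations, scale, offset,
//                                /*epsilon=*/1e-3f, /*feature_index=*/3);
//   XlaOp normalized = GetTupleElement(bn, 0);
//   XlaOp batch_mean = GetTupleElement(bn, 1);
//   XlaOp batch_var  = GetTupleElement(bn, 2);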

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64 dimension);

// Sets the size of the given dimension of the operand to `val`, marking that
// dimension as dynamic in the result.
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

// Returns the same op but with dynamic dimension removed.
XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
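
// Example (sketch): querying a dimension's size and then truncating it
// dynamically; `data` is an assumed [N] operand and `n` an assumed S32
// scalar with n <= N.
//
//   XlaOp full_size = GetDimensionSize(data, /*dimension=*/0);
//   XlaOp prefix = SetDimensionSize(data, n, /*dimension=*/0);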
2546
2547 // Implementation details below this point.
2548 //
2549
2550 // Free function template implementations.
2551
2552 template <typename NativeT>
ConstantR0(XlaBuilder * builder,NativeT value)2553 XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
2554 return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
2555 }
2556
2557 template <typename NativeT>
ConstantR1(XlaBuilder * builder,absl::Span<const NativeT> values)2558 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
2559 BorrowingLiteral literal(
2560 reinterpret_cast<const char*>(values.begin()),
2561 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
2562 {static_cast<int64>(values.size())}));
2563 return ConstantLiteral(builder, literal);
2564 }
2565
2566 template <typename NativeT>
ConstantR1(XlaBuilder * builder,int64 length,NativeT value)2567 XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
2568 Literal literal(ShapeUtil::MakeShape(
2569 primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
2570 literal.PopulateWithValue(value);
2571 return ConstantLiteral(builder, literal);
2572 }
2573
ConstantR1(XlaBuilder * builder,const tensorflow::core::Bitmap & values)2574 inline XlaOp ConstantR1(XlaBuilder* builder,
2575 const tensorflow::core::Bitmap& values) {
2576 return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
2577 }
2578
2579 template <typename NativeT>
ConstantR2(XlaBuilder * builder,std::initializer_list<std::initializer_list<NativeT>> values)2580 XlaOp ConstantR2(XlaBuilder* builder,
2581 std::initializer_list<std::initializer_list<NativeT>> values) {
2582 return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
2583 }
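
// For example (sketch, assuming an XlaBuilder `b` in scope), the ConstantR0,
// ConstantR1, and ConstantR2 helpers above can be used as:
//
//   XlaOp scalar = ConstantR0<float>(&b, 2.5f);
//   XlaOp vec    = ConstantR1<float>(&b, {1.0f, 2.0f, 3.0f});
//   XlaOp mat    = ConstantR2<float>(&b, {{1.0f, 2.0f}, {3.0f, 4.0f}});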

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_