/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;
class XlaOp;
class HloInstruction;

namespace internal {

struct XlaBuilderFriend {
  static XlaOp BuildFusion(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           absl::string_view fusion_kind,
                           const XlaComputation& fused_computation);

  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
                            const Shape& shape);

  static HloInstructionProto* GetInstruction(XlaOp op);
};

}  // namespace internal

// This represents an instruction that has been enqueued using the XlaBuilder.
// This is used to pass to subsequent computations that depend upon the
// instruction as an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar(). Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point possibly
  // deep in the callstack when we finally dereference `this`. The precondition
  // lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(XlaOp rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64_t handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;
  friend class ValueInference;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};

// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(XlaOp x);
XlaOp operator+(XlaOp x, XlaOp y);
XlaOp operator-(XlaOp x, XlaOp y);
XlaOp operator*(XlaOp x, XlaOp y);
XlaOp operator/(XlaOp x, XlaOp y);
XlaOp operator%(XlaOp x, XlaOp y);

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(XlaOp x);
XlaOp operator&(XlaOp x, XlaOp y);
XlaOp operator|(XlaOp x, XlaOp y);
XlaOp operator^(XlaOp x, XlaOp y);
XlaOp operator<<(XlaOp x, XlaOp y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise
// performs a right logical shift.
XlaOp operator>>(XlaOp x, XlaOp y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because
// the semantics might be surprising, since these operators usually yield
// 'bool' but here would yield an XlaOp. Further, programmers may expect == to
// be a structural equality. We also choose not to overload any of the
// mutating operators (e.g., +=, -=) because the semantics might be
// misleading: XLA computations are immutable.
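//
// Example usage of the overloads (a minimal sketch: `b` is an XlaBuilder
// owned by the caller, and `Parameter` is the free function declared later in
// this file):
//
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8}), "y");
//   XlaOp z = (x + y) * x;  // Enqueues an add followed by a multiply.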

// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  virtual ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }

  // Swaps the passed op metadata with the ones currently set.
  //
  // Returns the old op metadata.
  OpMetadata SwapOpMetadata(OpMetadata metadata) {
    OpMetadata old_metadata = std::move(metadata_);
    metadata_ = std::move(metadata);
    return old_metadata;
  }

  // Similar to SetOpMetadata, but only sets the metadata for the next op.
  void SetOneShotOpMetadata(OpMetadata metadata) {
    one_shot_metadata_ = std::move(metadata);
  }

  // Clears the HloMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }
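
  // Example (a sketch; OpMetadata's setters are the standard ones generated
  // from xla_data.proto, and `b`, `x`, and `zero` are defined by the caller):
  //
  //   OpMetadata m;
  //   m.set_op_type("Relu");
  //   m.set_op_name("model/relu_1");
  //   b.SetOpMetadata(m);      // All ops enqueued below carry `m`.
  //   auto y = Max(x, zero);   // Tagged with `m`.
  //   b.ClearOpMetadata();     // Subsequent ops are untagged.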

  // Sets an OpSharding that will be attached to all instructions until
  // cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Sets the FrontendAttributes that will be added to all instructions until
  // cleared.
  //
  // FrontendAttributes are often applied to a series of XLA HLO instructions.
  // As a result they are set on the computation builder and all the
  // instructions generated via the computation builder will have the same
  // frontend attributes attached to them.
  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
    frontend_attributes_ = frontend_attributes;
  }

  // Swaps the passed FrontendAttributes with the ones currently set.
  //
  // Returns the old attributes.
  FrontendAttributes SwapFrontendAttributes(
      const FrontendAttributes& frontend_attributes) {
    FrontendAttributes old_attributes = std::move(frontend_attributes_);
    frontend_attributes_ = frontend_attributes;
    return old_attributes;
  }
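
  // Example of a scoped swap (a sketch; the attribute key is hypothetical):
  //
  //   FrontendAttributes attrs;
  //   (*attrs.mutable_map())["_my_attribute"] = "1";
  //   FrontendAttributes saved = b.SwapFrontendAttributes(attrs);
  //   ... build the ops that should carry the attribute ...
  //   b.SetFrontendAttributes(saved);  // Restores the previous state.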

  // Returns the FrontendAttributes that will be attached to all instructions.
  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Clears all the frontend attributes.
  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }

  // Clears the sharding. Ops will be sharded according to the default
  // placement policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build()
  // is called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64_t kConvBatchDimension = 0;
  static constexpr int64_t kConvFeatureDimension = 1;
  static constexpr int64_t kConvFirstSpatialDimension = 2;
  static constexpr int64_t kConvSecondSpatialDimension = 3;
  static constexpr int64_t kConvKernelOutputDimension = 0;
  static constexpr int64_t kConvKernelInputDimension = 1;
  static constexpr int64_t kConvKernelFirstSpatialDimension = 2;
  static constexpr int64_t kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
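
  // Example (a sketch): build a scalar-add reducer for use with Reduce.
  // `Parameter` and `Add` are the free functions from the XLA client library.
  //
  //   std::unique_ptr<XlaBuilder> sub = b.CreateSubBuilder("add");
  //   const Shape scalar = ShapeUtil::MakeShape(F32, {});
  //   Add(Parameter(sub.get(), 0, scalar, "x"),
  //       Parameter(sub.get(), 1, scalar, "y"));
  //   XlaComputation add_reducer = sub->BuildAndNoteError();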

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = false);
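
  // Example (a minimal sketch; `Parameter` and `Add` are the free functions
  // declared later in this file):
  //
  //   XlaBuilder b("vector_add");
  //   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
  //   auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4}), "y");
  //   Add(x, y);  // The last enqueued op becomes the root.
  //   StatusOr<XlaComputation> computation = b.Build();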

  // Builds the computation with the requested operations, or notes an error in
  // the parent XlaBuilder and returns an empty computation if building failed.
  // This function is intended to be used where the returned XlaComputation is
  // only used by the parent XlaBuilder and hence further operations on the
  // returned XlaComputation will simply error out if an error occurred
  // while building this computation. If the built computation is to be used by
  // an XlaBuilder other than the parent XlaBuilder, then Build() should be
  // used instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph rooted at the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(XlaOp op) const;

  // Returns the shape of the given op.
  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using
  // the given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //   Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value
  // can be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns
  // an invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
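
  // Example of the typical pattern inside op-building helpers (a sketch;
  // TF_ASSIGN_OR_RETURN comes from status_macros.h and InvalidArgument from
  // the XLA util library):
  //
  //   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  //     TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
  //     if (!shape->IsArray()) {
  //       return InvalidArgument("expected an array shape");
  //     }
  //     return operand;
  //   });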

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(XlaOp operand) const;
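
  // For example, `ConstantR0<float>(&b, 1.0f) + ConstantR0<float>(&b, 2.0f)`
  // is a compile-time constant, while any expression that transitively uses a
  // Parameter is not. (`ConstantR0` is a client-library helper declared
  // elsewhere.)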

  // Sets up a binding which indicates that the `target_dim_num` in the
  // subshape `target_param_index` of parameter `target_param_num` is a dynamic
  // dimension and its real dynamic size is represented by
  // `dynamic_size_param_index` in parameter `dynamic_size_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.")
  Status SetDynamicBinding(int64_t dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64_t target_param_num,
                           ShapeIndex target_param_index,
                           int64_t target_dim_num);

  // Adds a new input/output alias. Since the input/output shape information is
  // not available until the computation is built, any eventual error in the
  // arguments of this API will be detected only at computation Build() time.
  //
  // Note: Except when 'must-alias' is true, the alias is assumed to be
  // 'may-alias', and only a buffer donated at runtime will be aliased with the
  // output. If a buffer is not donated at runtime, a copy will be inserted by
  // XLA to prevent buffer clobbering.
  void SetUpAlias(const ShapeIndex& output_index, int64_t param_number,
                  const ShapeIndex& param_index,
                  HloInputOutputAliasConfig::AliasKind kind =
                      HloInputOutputAliasConfig::AliasKind::kMayAlias) {
    input_output_aliases_.push_back(
        {output_index, param_number, param_index, kind});
  }
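
  // Example (a sketch): alias the whole buffer of parameter 0 with element
  // {1} of the result tuple, as a may-alias:
  //
  //   b.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
  //                /*param_index=*/{});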

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64 param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
    // Specifies if the alias is a must alias or may alias.
    HloInputOutputAliasConfig::AliasKind kind;
  };

  // Looks up the HloInstruction and sets the frontend attribute "attribute"
  // to "value".
  //
  // If the attribute already existed, then its value is updated.
  //
  // Note: the attribute is only added to the HloInstruction, not to the
  // builder.
  Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
                                         string value);

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // Converts the op to string for ease of debugging.
  std::string OpToString(XlaOp op) const;

 private:
  void ToStringHelper(std::string* out, int ident, int64_t op_handle) const;

  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64_t root_id,
                                 bool remove_dynamic_dimensions);

  // Descriptions for the methods below can be found in the corresponding
  // public functions section in this file.

  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const string& name,
                  const std::vector<bool>& replicated_at_leaf_buffers);
  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const string& name) {
    std::vector<bool> empty_bools;
    return Parameter(parameter_number, shape, name, empty_bools);
  }

  virtual XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);

  XlaOp BroadcastInDim(XlaOp operand,
                       const absl::Span<const int64> out_dim_size,
                       const absl::Span<const int64> broadcast_dimensions);

  XlaOp Pad(XlaOp operand, XlaOp padding_value,
            const PaddingConfig& padding_config);
  XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                 int64_t pad_lo, int64_t pad_hi);

  virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                                      XlaOp padding_value,
                                      const PaddingConfig& padding_config);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64_t inferred_dimension = -1);

  XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                       absl::Span<const int64> new_size_bounds,
                       const std::vector<bool>& dims_are_dynamic);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);
  virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> start_indices,
                                        absl::Span<const int64> limit_indices,
                                        absl::Span<const int64> strides);
  virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                           int64_t limit_index, int64_t stride, int64_t dimno);

  XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64> slice_sizes);
  virtual StatusOr<XlaOp> DynamicSliceInternal(
      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
      absl::Span<const int64> slice_sizes);

  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                           absl::Span<const XlaOp> start_indices);
  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
      const Shape& shape, XlaOp operand, XlaOp update,
      absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64_t dimension);
  virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
                                              absl::Span<const XlaOp> operands,
                                              int64_t dimension);

  void Trace(const string& tag, XlaOp operand);

  XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);
  virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
                                        absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
                                                  XlaOp tuple_data,
                                                  int64_t index);

  XlaOp Dot(
      XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DotGeneral(
      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp Conv(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, int64_t feature_group_count = 1,
      int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  StatusOr<HloInstructionProto> DynamicConvInstruction(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config);

  XlaOp Fft(XlaOp operand, FftType fft_type,
            absl::Span<const int64> fft_length);
  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                                      FftType fft_type,
                                      absl::Span<const int64> fft_length);

  virtual StatusOr<XlaOp> TriangularSolveInternal(
      const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options);

  virtual StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
                                           bool lower);

  XlaOp Infeed(const Shape& shape, const string& config = "");
  XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
  virtual StatusOr<XlaOp> InfeedWithTokenInternal(
      const Shape& infeed_instruction_shape, XlaOp token,
      const string& config);

  void Outfeed(XlaOp operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);
  virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
      const string& outfeed_config);
  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, absl::optional<Window> window,
      absl::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  // Internal version of CustomCall without computation that doesn't do op
  // specific error handling and expects arguments to be legal. The CustomCall
  // method above calls this method after error handling.
  virtual StatusOr<XlaOp> CustomCallInternal(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, absl::optional<Window> window,
      absl::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation& computation, const Shape& shape_with_layout,
      const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);

  XlaOp Reduce(XlaOp operand, XlaOp init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  virtual StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64> dimensions_to_reduce);

  XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                     absl::Span<const XlaOp> init_values,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  StatusOr<HloInstructionProto> ReduceWindowInternal(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  virtual StatusOr<XlaOp> ReduceWindowInternal(
      const Shape& shape, XlaOp operand, XlaOp init_value,
      const XlaComputation& computation, Window window);
  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp AllGather(
      XlaOp operand, int64_t all_gather_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Layout>& layout = absl::nullopt,
      const absl::optional<bool> use_global_device_ids = absl::nullopt);

  XlaOp AllReduce(
      XlaOp operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Shape>& shape_with_layout = absl::nullopt);

  XlaOp ReduceScatter(
      XlaOp operand, const XlaComputation& computation,
      int64_t scatter_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Layout>& layout = absl::nullopt,
      const absl::optional<bool> use_global_device_ids = absl::nullopt);

  XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
                 int64_t concat_dimension, int64_t split_count,
                 absl::Span<const ReplicaGroup> replica_groups,
                 const absl::optional<Layout>& layout = absl::nullopt);

  XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
                      int64_t concat_dimension, int64_t split_count,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const absl::optional<Layout>& layout);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, XlaOp source, XlaOp init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  StatusOr<HloInstructionProto> SelectAndScatterInternal(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension);

  XlaOp Iota(PrimitiveType type, int64_t size);

  XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
                                                     XlaOp operand);

  XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
  virtual StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);

  XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                                      absl::Span<const int64> dimensions);

  XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
             int64_t dimension = -1, bool is_stable = false);
  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
                                       absl::Span<const XlaOp> operands,
                                       const XlaComputation& comparator,
                                       int64_t dimension, bool is_stable);

  XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

  XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

  XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                        const Shape& shape);
  // Internal variant for the op with the full result shape containing both
  // data and state shape as a tuple.
  virtual StatusOr<XlaOp> RngBitGeneratorInternal(
      const Shape& full_result_shape, RandomAlgorithm algorithm,
      XlaOp initial_state);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              XlaOp init);
  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                        const XlaComputation& condition,
                                        const XlaComputation& body,
                                        XlaOp init);

  XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                    const XlaComputation& true_computation, XlaOp false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(XlaOp branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                        const int mantissa_bits);
  virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
                                                  XlaOp operand,
                                                  const int exponent_bits,
                                                  const int mantissa_bits);

  XlaOp Gather(XlaOp input, XlaOp start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes,
               bool indices_are_sorted = false);

  virtual StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);

  virtual StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
      const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
      bool unique_indices);

  void Send(XlaOp operand, const ChannelHandle& handle);
  XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

  XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                   const ChannelHandle& handle);

  XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                     const ChannelHandle& handle);

  virtual XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                          float epsilon, int64_t feature_index);

  XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                           XlaOp variance, float epsilon,
                           int64_t feature_index);

  XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                      XlaOp batch_var, XlaOp grad_output, float epsilon,
                      int64_t feature_index);

  XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

  XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

  virtual StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape,
                                                   XlaOp operand, XlaOp val,
                                                   int64_t dimension);

  XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);

  virtual StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                         HloOpcode opcode,
                                         absl::Span<const XlaOp> operands);
  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                 HloOpcode opcode) {
    return AddInstruction(std::move(instr), opcode, /*operands=*/{});
  }

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64_t handle) const;
  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(
      int64_t handle);

  // Internal helper method that does the building for an arbitrary unary op.
  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction is
  // only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                 absl::Span<const int64> broadcast_dimensions,
                 absl::optional<ComparisonDirection> direction = absl::nullopt,
                 absl::optional<Comparison::Type> type = absl::nullopt);

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction);

  // Internal helper method for binary op compare without broadcast dimensions.
  virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                  ComparisonDirection direction,
                                  Comparison::Type type);

  // Internal helper method that does the building for an arbitrary binary op
  // with same ranked operands that doesn't broadcast.
  virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                    XlaOp lhs, XlaOp rhs);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  virtual StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                        absl::Span<const XlaOp> parameters,
                                        const Shape& shape);

  virtual StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       XlaOp operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  virtual StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                          int64_t inferred_dimension);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64_t root_id) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false iff
  // a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64_t op_handle, int depth,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  int64 GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored within
  // the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
  std::vector<std::unique_ptr<Shape>> instruction_shapes_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias()
  // API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where
  // the instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // Track imported instructions by their computation id and the position in
  // their computation's instruction list.
  struct ImportedInstruction {
    int64 computation_id;
    int64 instruction_index;
  };

  absl::flat_hash_map<int64, ImportedInstruction> handle_to_imported_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // A temporary metadata that will only be applied to the next op created.
  absl::optional<OpMetadata> one_shot_metadata_;

  // Sharding for this operator. This is structured as a "modal"-like
  // operation, in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  FrontendAttributes frontend_attributes_;

  friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                         const Shape& shape, const string& name,
                         const std::vector<bool>& replicated_at_leaf_buffers);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(XlaOp operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      XlaOp operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Copy(XlaOp operand);

  friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                        int64_t pad_lo, int64_t pad_hi);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                              absl::Span<const int64> new_size_bounds,
                              const std::vector<bool>& dims_are_dynamic);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64> new_sizes,
                                            int64_t inferred_dimension);

  friend XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  friend XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                          int64_t limit_index, int64_t stride, int64_t dimno);

  friend XlaOp DynamicSlice(XlaOp operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64_t dimension);

  friend void Trace(const string& tag, XlaOp operand);

  friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction,
                       Comparison::Type compare_type);
  friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
                   const PrecisionConfig* precision_config,
                   absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config,
                          absl::optional<PrimitiveType> preferred_element_type);
  virtual StatusOr<XlaOp> DotGeneralInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs,
      const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config);
  friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64_t feature_group_count, int64_t batch_group_count,
                    const PrecisionConfig* precision_config,
                    absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvKernelGrad(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp Fft(XlaOp operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape,
      const string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);
  friend XlaOp CustomCallWithComputation(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const XlaComputation& computation,
      const Shape& shape, const string& opaque, bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);
  friend XlaOp CustomCallWithConvDnums(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape,
      absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);
  friend XlaOp Complex(XlaOp real, XlaOp imag,
                       absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Conj(XlaOp operand);
  friend XlaOp Add(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Div(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Max(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Min(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp And(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Or(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Not(XlaOp operand);
  friend XlaOp PopulationCount(XlaOp operand);
  friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                         absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
                                 absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                            absl::Span<const XlaOp> init_values,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);

  friend XlaOp CrossReplicaSum(XlaOp operand,
                               absl::Span<const ReplicaGroup> replica_groups);
  friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
                         int64_t shard_count,
                         absl::Span<const ReplicaGroup> replica_groups,
                         const absl::optional<ChannelHandle>& channel_id,
                         const absl::optional<Layout>& layout,
                         const absl::optional<bool> use_global_device_ids);
  friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                         absl::Span<const ReplicaGroup> replica_groups,
                         const absl::optional<ChannelHandle>& channel_id,
                         const absl::optional<Shape>& shape_with_layout);
  friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
                             int64_t scatter_dimension, int64_t shard_count,
                             absl::Span<const ReplicaGroup> replica_groups,
                             const absl::optional<ChannelHandle>& channel_id,
                             const absl::optional<Layout>& layout,
                             const absl::optional<bool> use_global_device_ids);

  friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
                        int64_t concat_dimension, int64_t split_count,
                        absl::Span<const ReplicaGroup> replica_groups,
                        const absl::optional<Layout>& layout);
  friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
                             int64_t concat_dimension, int64_t split_count,
                             absl::Span<const ReplicaGroup> replica_groups,
1349 const absl::optional<Layout>& layout);
1350 friend XlaOp CollectivePermute(
1351 XlaOp operand,
1352 const std::vector<std::pair<int64, int64>>& source_target_pairs);
1353 friend XlaOp ReplicaId(XlaBuilder* builder);
1354 friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
1355 absl::Span<const int64> window_dimensions,
1356 absl::Span<const int64> window_strides,
1357 Padding padding, XlaOp source, XlaOp init_value,
1358 const XlaComputation& scatter);
1359 friend XlaOp SelectAndScatterWithGeneralPadding(
1360 XlaOp operand, const XlaComputation& select,
1361 absl::Span<const int64> window_dimensions,
1362 absl::Span<const int64> window_strides,
1363 absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
1364 XlaOp init_value, const XlaComputation& scatter);
1365 friend XlaOp Abs(XlaOp operand);
1366 friend XlaOp Atan2(XlaOp y, XlaOp x,
1367 absl::Span<const int64> broadcast_dimensions);
1368 friend XlaOp Exp(XlaOp operand);
1369 friend XlaOp Expm1(XlaOp operand);
1370 friend XlaOp Floor(XlaOp operand);
1371 friend XlaOp Ceil(XlaOp operand);
1372 friend XlaOp Round(XlaOp operand);
1373 friend XlaOp Log(XlaOp operand);
1374 friend XlaOp Log1p(XlaOp operand);
1375 friend XlaOp Logistic(XlaOp operand);
1376 friend XlaOp Sign(XlaOp operand);
1377 friend XlaOp Clz(XlaOp operand);
1378 friend XlaOp Cos(XlaOp operand);
1379 friend XlaOp Sin(XlaOp operand);
1380 friend XlaOp Tanh(XlaOp operand);
1381 friend XlaOp Real(XlaOp operand);
1382 friend XlaOp Imag(XlaOp operand);
1383 friend XlaOp Sqrt(XlaOp operand);
1384 friend XlaOp Rsqrt(XlaOp operand);
1385 friend XlaOp Cbrt(XlaOp operand);
1386 friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
1387 absl::Span<const int64> broadcast_dimensions);
1388 friend XlaOp IsFinite(XlaOp operand);
1389 friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
1390 int64_t iota_dimension);
1391 friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);
1392 friend XlaOp ConvertElementType(XlaOp operand,
1393 PrimitiveType new_element_type);
1394 friend XlaOp BitcastConvertType(XlaOp operand,
1395 PrimitiveType new_element_type);
1396 friend XlaOp Neg(XlaOp operand);
1397 friend XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
1398 friend XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
1399 friend XlaOp Sort(absl::Span<const XlaOp> operands,
1400 const XlaComputation& comparator, int64_t dimension,
1401 bool is_stable);
1402 friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
1403 friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1404 const XlaComputation& computation,
1405 absl::Span<const int64> dimensions,
1406 absl::Span<const XlaOp> static_operands);
1407 friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
1408 friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
1409 friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
1410 const Shape& shape);
1411 friend XlaOp While(const XlaComputation& condition,
1412 const XlaComputation& body, XlaOp init);
1413 friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
1414 const XlaComputation& true_computation,
1415 XlaOp false_operand,
1416 const XlaComputation& false_computation);
1417 friend XlaOp Conditional(
1418 XlaOp branch_index,
1419 absl::Span<const XlaComputation* const> branch_computations,
1420 absl::Span<const XlaOp> branch_operands);
1421 friend XlaOp ConditionalImpl(
1422 XlaOp branch_index,
1423 absl::Span<const XlaComputation* const> branch_computations,
1424 absl::Span<const XlaOp> branch_operands);
1425 friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
1426 const int mantissa_bits);
1427 friend XlaOp Gather(XlaOp input, XlaOp start_indices,
1428 const GatherDimensionNumbers& dimension_numbers,
1429 absl::Span<const int64> slice_sizes,
1430 bool indices_are_sorted);
1431 friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1432 const XlaComputation& update_computation,
1433 const ScatterDimensionNumbers& dimension_numbers,
1434 bool indices_are_sorted, bool unique_indices);
1435 friend void Send(XlaOp operand, const ChannelHandle& handle);
1436 friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1437 const ChannelHandle& handle);
1438 friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
1439 float epsilon, int64_t feature_index);
1440 friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
1441 XlaOp mean, XlaOp variance, float epsilon,
1442 int64_t feature_index);
1443 friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
1444 XlaOp batch_var, XlaOp grad_output, float epsilon,
1445 int64_t feature_index);
1446 friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
1447 const ChannelHandle& handle);
1448 friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
1449 const ChannelHandle& handle);
1450 friend XlaOp SendToHost(XlaOp operand, XlaOp token,
1451 const Shape& shape_with_layout,
1452 const ChannelHandle& handle);
1453 friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
1454 const ChannelHandle& handle);
1455 friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1456 const string& config);
1457 friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1458 const Shape& shape_with_layout,
1459 const string& outfeed_config);
1460 friend XlaOp CreateToken(XlaBuilder* builder);
1461 friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1462
1463 friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);
1464 friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);
1465 friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);
1466
1467 protected:
1468 // Returns OK status if the given op was built using this builder. Otherwise,
1469 // returns an error.
1470 Status CheckOpBuilder(XlaOp op) const;
1471
1472 private:
1473 XlaOp ConditionalImpl(
1474 XlaOp branch_index,
1475 absl::Span<const XlaComputation* const> branch_computations,
1476 absl::Span<const XlaOp> branch_operands);
1477
1478 XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
1479 int64_t concat_dimension, int64_t split_count,
1480 absl::Span<const ReplicaGroup> replica_groups);
1481
1482 // Creates an op with the given opcode and the output shape.
1483 virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
1484 absl::Span<const XlaOp> operands);
1485
1486 // Here, InstructionType is either const HloInstructionProto* or non-const
1487 // HloInstructionProto*.
1488 template <typename InstructionType>
1489   StatusOr<InstructionType> LookUpInstructionByHandleInternal(
1490 int64_t handle) const {
1491 auto it = handle_to_index_.find(handle);
1492 if (it == handle_to_index_.end()) {
1493       // Try to find the instruction among the imported instructions.
1494 auto imported_it = handle_to_imported_index_.find(handle);
1495 if (imported_it != handle_to_imported_index_.end()) {
1496 ImportedInstruction imported = imported_it->second;
1497 return const_cast<InstructionType>(
1498 &embedded_.at(imported.computation_id)
1499 .instructions(imported.instruction_index));
1500 }
1501 return InvalidArgument("No XlaOp with handle %d", handle);
1502 }
1503 return const_cast<InstructionType>(&instructions_.at(it->second));
1504 }
1505
1506 // Here, InstructionType is either const HloInstructionProto* or non-const
1507 // HloInstructionProto*.
1508 //
1509 // TODO(hinsu): Return const pointer within StatusOr and use
1510 // absl::implicit_cast at callsites. This requires implicit_cast support in
1511 // stream_executor::port::StatusOr similar to absl::StatusOr.
1512 template <typename InstructionType>
1513   StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
1514 TF_RETURN_IF_ERROR(CheckOpBuilder(op));
1515 return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
1516 }
1517
1518 friend struct internal::XlaBuilderFriend;
1519
1520 friend class ValueInference;
1521 };
1522
1523 // RAII-style object: sets the current sharding assignment in builder on
1524 // construction, and sets back to the previous assignment on destruction.
1525 class XlaScopedShardingAssignment {
1526 public:
1527   XlaScopedShardingAssignment(xla::XlaBuilder* builder,
1528 absl::optional<OpSharding> sharding)
1529 : builder_(builder), prev_sharding_(builder->sharding()) {
1530 SetSharding(sharding);
1531 }
1532
1533 XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
1534 XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
1535 delete;
1536
1537   ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
1538
1539 private:
1540   void SetSharding(const absl::optional<OpSharding>& sharding) {
1541 if (sharding.has_value()) {
1542 builder_->SetSharding(sharding.value());
1543 } else {
1544 builder_->ClearSharding();
1545 }
1546 }
1547
1548 xla::XlaBuilder* const builder_;
1549 absl::optional<OpSharding> prev_sharding_;
1550 };
1551
1552 // RAII-style object: save the current builder's frontend attributes, and merge
1553 // them with the new ones on construction.
1554 // Restore the original attributes on destruction.
1555 class XlaScopedFrontendAttributesAssignment {
1556 public:
1557   XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
1558 FrontendAttributes attributes)
1559 : builder_(builder) {
1560 saved_ = builder_->SwapFrontendAttributes(attributes);
1561 }
1562
1563   ~XlaScopedFrontendAttributesAssignment() {
1564 builder_->SetFrontendAttributes(saved_);
1565 }
1566
1567 private:
1568 xla::XlaBuilder* const builder_;
1569 FrontendAttributes saved_;
1570
1571 TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
1572 };
1573
1574 // RAII-style object: sets the current op metadata in builder on construction,
1575 // and sets back to the previous assignment on destruction.
1576 class XlaScopedOpMetadataAssignment {
1577 public:
1578   XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata)
1579 : builder_(builder) {
1580 saved_ = builder_->SwapOpMetadata(metadata);
1581 }
1582
1583   ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); }
1584
1585 private:
1586 xla::XlaBuilder* const builder_;
1587 OpMetadata saved_;
1588
1589 TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment);
1590 };
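
// Example (illustrative sketch): scoping metadata onto a group of ops. The
// builder `b` and ops `x`, `y` are assumed to exist already; the same pattern
// applies to XlaScopedShardingAssignment and
// XlaScopedFrontendAttributesAssignment.
//
//   {
//     OpMetadata metadata;
//     metadata.set_op_name("my_op");  // OpMetadata proto field.
//     XlaScopedOpMetadataAssignment scoped_metadata(&b, metadata);
//     Add(x, y);  // Tagged with the metadata set above.
//   }  // The previous metadata is restored here.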
1591
1592 // Free functions for building XlaOps. The intention is that these will
1593 // become the public API for building XlaOps rather than calling methods on
1594 // XlaBuilder directly.
1595 //
1596
1597 // Enqueues a "retrieve parameter value" instruction for a parameter that was
1598 // passed to the computation.
1599 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
1600 const Shape& shape, const string& name);
1601
1602 // Same as above, but with leaf buffer replication annotation.
1603 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
1604 const Shape& shape, const string& name,
1605 const std::vector<bool>& replicated_at_leaf_buffers);
1606
1607 // Enqueues a constant with the value of the given literal onto the
1608 // computation.
1609 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
1610
1611 // Enqueues a constant onto the computation. Methods are templated on the
1612 // native host type (NativeT) which corresponds to a specific XLA
1613 // PrimitiveType as given in the following table:
1614 //
1615 // Native Type PrimitiveType
1616 // -----------------------------
1617 // bool PRED
1618 // int32 S32
1619 // int64 S64
1620 // uint32 U32
1621 // uint64 U64
1622 // float F32
1623 // double F64
1624 //
1625 // Note: not all primitive types defined in xla_data.proto have a
1626 // corresponding native type yet.
1627 template <typename NativeT>
1628 XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
1629 template <typename NativeT>
1630 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
1631 XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
1632 template <typename NativeT>
1633 XlaOp ConstantR2(XlaBuilder* builder,
1634 std::initializer_list<std::initializer_list<NativeT>> values);
1635 template <typename NativeT>
1636 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
1637 const Array<NativeT>& values,
1638 const Layout& layout);
1639 template <typename NativeT>
1640 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
1641 template <typename NativeT>
1642 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
1643 const Array2D<NativeT>& values,
1644 const Layout& layout);
1645 template <typename NativeT>
1646 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
1647 const Array2D<NativeT>& values);
1648 template <typename NativeT>
1649 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
1650 const Array3D<NativeT>& values,
1651 const Layout& layout);
1652 template <typename NativeT>
1653 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
1654 const Array3D<NativeT>& values);
1655 template <typename NativeT>
1656 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
1657 const Array4D<NativeT>& values,
1658 const Layout& layout);
1659 template <typename NativeT>
1660 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
1661 const Array4D<NativeT>& values);
1662
1663 // Enqueues a rank one constant (vector) onto the
1664 // computation. The vector has size 'length' and every element has the value
1665 // 'value'.
1666 template <typename NativeT>
1667 XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value);
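
// Example (illustrative sketch; `b` is an assumed XlaBuilder):
//
//   ConstantR0<float>(&b, 1.0f);                           // f32[] scalar
//   ConstantR1<int32>(&b, {1, 2, 3});                      // s32[3]
//   ConstantR1<float>(&b, /*length=*/4, /*value=*/0.0f);   // f32[4] of zeros
//   ConstantR2<float>(&b, {{1.0f, 2.0f}, {3.0f, 4.0f}});   // f32[2,2]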
1668
1669 // Adds dimensions to an array by duplicating the data in the array.
1670 //
1671 // The new dimensions are inserted on the left, i.e. if
1672 // broadcast_sizes has values {a0, ..., aN} and the operand shape
1673 // has dimensions {b0, ..., bM} then the shape of the output has
1674 // dimensions {a0, ..., aN, b0, ..., bM}.
1675 //
1676 // The new dimensions index into copies of the operand, i.e.
1677 //
1678 // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
1679 XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);
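
// Example (sketch; `x` is an assumed f32[2,3] operand):
//
//   Broadcast(x, /*broadcast_sizes=*/{4});  // Result shape: f32[4,2,3].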
1680
1681 // This op broadcasts the `operand` to an output with the given `shape`.
1682 // `broadcast_dimensions` are the dimensions to be broadcast into, i.e., the
1683 // i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
1684 // dimension of the output. This also requires that the i'th input dimension is
1685 // either 1 or is the same as the output dimension it's broadcasting into.
1686 //
1687 // For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
1688 // output shape is s32[2,2]:
1689 // - Specifying {1} as broadcast_dimension will generate output
1690 // {{1, 2},
1691 // {1, 2}}
1692 // - On the other hand, specifying {0} as broadcast_dimension
1693 // will generate output
1694 // {{1 , 1},
1695 // {2 , 2}}
1696 XlaOp BroadcastInDim(XlaOp operand, const absl::Span<const int64> out_dim_size,
1697 const absl::Span<const int64> broadcast_dimensions);
1698
1699 // Copies the input operand to the output. This operation is for internal
1700 // use and is only used by the compiler for optimization purposes or to
1701 // ensure correctness. The XLA client should never have to generate this
1702 // instruction.
1703 //
1704 // Copy has two potential use cases:
1705 //
1706 // * Create a copy of the operand with a new layout.
1707 //
1708 // * Create a copy of the operand in a separately allocated buffer. This is
1709 // necessary for some backends if the operand is a parameter or constant and
1710 // the operand is returned within a tuple. In this case, the lifetime of the
1711 // operand buffer must be the same as the lifetime of the output result.
1712 // However, the lifetimes of parameters and constants are managed separately
1713 // from the lifetime of the output result. Creating a separate copy of the
1714 // parameter or constant buffer resolves this issue.
1715 XlaOp Copy(XlaOp operand);
1716
1717 // Enqueues a pad operation onto the computation that pads the given value on
1718 // the edges as well as between the elements of the input. padding_config
1719 // specifies the padding amount for each dimension.
1720 XlaOp Pad(XlaOp operand, XlaOp padding_value,
1721 const PaddingConfig& padding_config);
1722
1723 // Enqueues a pad operation in a given dimension, taking all other
1724 // dimensions as they are.
1725 XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
1726 int64_t pad_lo, int64_t pad_hi);
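
// Example (sketch; `x` is an assumed f32[2,3] operand on builder `b`):
//
//   // One zero before and after dimension 1: f32[2,3] -> f32[2,5].
//   PadInDim(x, ConstantR0<float>(&b, 0.0f), /*dimno=*/1,
//            /*pad_lo=*/1, /*pad_hi=*/1);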
1727
1728 // Enqueues an operation onto the computation that flattens the operand based
1729 // on the dimension order (major/slowest-varying to minor/fastest-varying)
1730 // given, followed by reshaping it into the shape with the given dimension
1731 // sizes (also major to minor). Conceptually, this is a limited form of
1732 // "shape casting".
1733 XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
1734 absl::Span<const int64> new_sizes);
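
// Example (sketch; `x` is an assumed f32[2,3] operand):
//
//   // Flatten in row-major order ({0, 1}), then reshape to f32[3,2].
//   Reshape(x, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2});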
1735
1736 // Enqueues a dynamic reshape operation. The dynamic reshape takes additional
1737 // XlaOps as sizes for the result dimensions. The result dimension i is a
1738 // dynamic dimension if dims_are_dynamic[i] is true.
1739 XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
1740 absl::Span<const int64> new_size_bounds,
1741 const std::vector<bool>& dims_are_dynamic);
1742
1743 // Enqueues an operation onto the computation that collapses the operand,
1744 // from first to last dimension (C order), then reshapes it to the given
1745 // dimension sizes. Conceptually, this is a limited form of "shape casting".
1746 XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
1747
1748 // Enqueues a Reshape op that uses an explicit target shape.
1749 XlaOp Reshape(const Shape& shape, XlaOp operand);
1750
1751 // `inferred_dimension` represents the output dimension that's inferred by an
1752 // upper-level framework by dividing the input element count by the known
1753 // output element count. While an inferred_dimension can be static, if there
1754 // is a dynamic dimension in the output, it must be the inferred dimension.
1755 XlaOp ReshapeWithInferredDimension(XlaOp operand,
1756 absl::Span<const int64> new_sizes,
1757 int64_t inferred_dimension);
1758
1759 // Wrapper for Reshape.
1760 // Enqueues an operation to collapse the provided dimensions; e.g. an
1761 // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
1762 // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
1763 // be a consecutive, in-order subsequence of the operand dimensions.
1764 //
1765 // Note that collapsing a single dimension does nothing:
1766 //
1767 // {256} collapsing {0} => {256}
1768 // {1} collapsing {0} => {1}
1769 //
1770 // Collapsing multiple dimensions produces a single result dimension:
1771 //
1772 // {256, 2} collapsing {0,1} => {512}
1773 // {256, 2, 3} collapsing {0,1} => {512, 3}
1774 //
1775 // This could potentially cause data to be moved -- it provides a more
1776 // structured form of reshaping than an arbitrary Reshape operation.
1777 XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
1778
1779 // Enqueues a slice operation onto the computation that slices the operand
1780 // from the start indices to the limit indices; e.g.
1781 //
1782 // x
1783 // [ 0 1 2 3 ]
1784 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
1785 // [ 8 9 a b ]
1786 //
1787 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
1788 // range notation.
1789 // The strides parameter determines the stride over the slice.
1790 XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
1791 absl::Span<const int64> limit_indices,
1792 absl::Span<const int64> strides);
1793
1794 // Enqueues a slice operation in a given dimension, taking all other
1795 // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
1796 // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
1797 // for:
1798 //
1799 // array[:, 2:4:1, :]
1800 XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index,
1801 int64_t stride, int64_t dimno);
1802
1803 // Enqueues a slice operation onto the computation that slices the 'operand'
1804 // from dynamic start indices which are passed in 'start_indices'.
1805 // The size of the slice in each dimension is passed in 'slice_sizes',
1806 // which specify the end point of exclusive slice intervals in each
1807 // dimension [start, start + size).
1808 // The shape of each element of 'start_indices' must be scalar, with the span
1809 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1810 // have the same shape.
1811 // Slice index calculations are computed modulo input dimension sizes to
1812 // prevent dynamic start indices from generating out-of-bound array accesses.
1813 XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
1814 absl::Span<const int64> slice_sizes);
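
// Example (sketch; `x` is an assumed f32[4,4] operand on builder `b`):
//
//   // Extract a 2x2 window starting at runtime indices (i, j).
//   XlaOp i = ConstantR0<int32>(&b, 1);
//   XlaOp j = ConstantR0<int32>(&b, 2);
//   DynamicSlice(x, /*start_indices=*/{i, j}, /*slice_sizes=*/{2, 2});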
1815
1816 // Enqueues a dynamic update slice operation onto the computation, which
1817 // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
1818 // The shape of 'update' determines the shape of the slice of 'operand'
1819 // which is updated.
1820 // The indices specified in 'start_indices' specify the offset of the slice
1821 // of 'operand' which is updated.
1822 //
1823 // update = {10, 11} // calculated at runtime.
1824 // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
1825 // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
1826 // [7 8 9] [7 8 9 ]
1827 //
1828 // The shape of each element of 'start_indices' must be scalar, with the span
1829 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1830 // have the same shape.
1831 // Slice index calculations are computed modulo update dimension sizes to
1832 // prevent dynamic start indices from generating out-of-bound array accesses.
1833 XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
1834 absl::Span<const XlaOp> start_indices);
1835
1836 // Enqueues a concatenate instruction onto the computation. 'operands' must
1837 // have >= 1 entry.
1838 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1839 int64_t dimension);
1840
1841 // Enqueues a tracing operation onto the computation; the computation will emit
1842 // a logging message with the operand.
1843 void Trace(const string& tag, XlaOp operand);
1844
1845 // Enqueues a conditional-move-like select operation onto the computation;
1846 // predicated on pred, selects between on_true and on_false.
1847 XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
1848
1849 // Enqueues a tuple-creation instruction onto the computation.
1850 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
1851
1852 // Enqueues a tuple-element-get instruction onto the computation.
1853 XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
1854
1855 // Enqueues an equal-to comparison instruction onto the computation.
1856 XlaOp Eq(XlaOp lhs, XlaOp rhs,
1857 absl::Span<const int64> broadcast_dimensions = {});
1858 XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
1859 absl::Span<const int64> broadcast_dimensions = {});
1860
1861 // Enqueues a not-equal comparison instruction onto the computation.
1862 XlaOp Ne(XlaOp lhs, XlaOp rhs,
1863 absl::Span<const int64> broadcast_dimensions = {});
1864 XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
1865 absl::Span<const int64> broadcast_dimensions = {});
1866
1867 // Enqueues a greater-or-equal comparison instruction onto the computation.
1868 XlaOp Ge(XlaOp lhs, XlaOp rhs,
1869 absl::Span<const int64> broadcast_dimensions = {});
1870 XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
1871 absl::Span<const int64> broadcast_dimensions = {});
1872
1873 // Enqueues a greater-than comparison instruction onto the computation.
1874 XlaOp Gt(XlaOp lhs, XlaOp rhs,
1875 absl::Span<const int64> broadcast_dimensions = {});
1876 XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
1877 absl::Span<const int64> broadcast_dimensions = {});
1878
1879 // Enqueues a less-than comparison instruction onto the computation.
1880 XlaOp Lt(XlaOp lhs, XlaOp rhs,
1881 absl::Span<const int64> broadcast_dimensions = {});
1882 XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
1883 absl::Span<const int64> broadcast_dimensions = {});
1884
1885 // Enqueues a less-or-equal comparison instruction onto the computation.
1886 XlaOp Le(XlaOp lhs, XlaOp rhs,
1887 absl::Span<const int64> broadcast_dimensions = {});
1888 XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
1889 absl::Span<const int64> broadcast_dimensions = {});
1890
1891 // Enqueues a comparison instruction onto the computation (optionally without
1892 // broadcast_dimensions for consistency with others).
1893 XlaOp Compare(XlaOp lhs, XlaOp rhs,
1894 absl::Span<const int64> broadcast_dimensions,
1895 ComparisonDirection direction, Comparison::Type compare_type);
1896 XlaOp Compare(XlaOp lhs, XlaOp rhs,
1897 absl::Span<const int64> broadcast_dimensions,
1898 ComparisonDirection direction);
1899 XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
1900
1901 // Enqueues a dot instruction onto the computation.
1902 XlaOp Dot(XlaOp lhs, XlaOp rhs,
1903 const PrecisionConfig* precision_config = nullptr,
1904 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1905
1906 // Enqueues a general dot instruction onto the computation.
1907 XlaOp DotGeneral(
1908 XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
1909 const PrecisionConfig* precision_config = nullptr,
1910 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
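
// Example (sketch): a plain matrix multiply of assumed operands `lhs`
// (f32[m,k]) and `rhs` (f32[k,n]) expressed via DotGeneral.
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);  // Contract lhs dimension 1 ...
//   dnums.add_rhs_contracting_dimensions(0);  // ... against rhs dimension 0.
//   DotGeneral(lhs, rhs, dnums);  // Same result as Dot(lhs, rhs).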
1911
1912 // Enqueues a convolution instruction onto the computation, which uses the
1913 // default convolution dimension numbers.
1914 XlaOp Conv(
1915 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1916 Padding padding, int64_t feature_group_count = 1,
1917 int64_t batch_group_count = 1,
1918 const PrecisionConfig* precision_config = nullptr,
1919 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
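
// Example (sketch; `activations` and `kernel` are assumed operands with two
// spatial dimensions):
//
//   // Stride-1, valid-padded convolution with default dimension numbers.
//   Conv(activations, kernel, /*window_strides=*/{1, 1}, Padding::kValid);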
1920
1921 // Enqueues a convolution instruction onto the computation, with the caller
1922 // provided padding configuration in the format returned by MakePadding().
1923 XlaOp ConvWithGeneralPadding(
1924 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1925 absl::Span<const std::pair<int64, int64>> padding,
1926 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1927 const PrecisionConfig* precision_config = nullptr,
1928 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1929
1930 // Enqueues a convolution instruction onto the computation, with the caller
1931 // provided dimension numbers configuration.
1932 XlaOp ConvWithGeneralDimensions(
1933 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1934 Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1935 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1936 const PrecisionConfig* precision_config = nullptr,
1937 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1938
1939 // Enqueues a convolution instruction onto the computation, with the caller
1940 // provided padding configuration as well as the dimension numbers.
1941 XlaOp ConvGeneral(
1942 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1943 absl::Span<const std::pair<int64, int64>> padding,
1944 const ConvolutionDimensionNumbers& dimension_numbers,
1945 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1946 const PrecisionConfig* precision_config = nullptr,
1947 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1948
1949 // Enqueues a convolution instruction onto the computation, with the caller
1950 // provided padding configuration, dilation factors and dimension numbers.
1951 XlaOp ConvGeneralDilated(
1952 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1953 absl::Span<const std::pair<int64, int64>> padding,
1954 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1955 const ConvolutionDimensionNumbers& dimension_numbers,
1956 int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1957 const PrecisionConfig* precision_config = nullptr,
1958 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1959
1960 XlaOp DynamicConvForward(
1961 XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1962 absl::Span<const std::pair<int64, int64>> padding,
1963 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1964 const ConvolutionDimensionNumbers& dimension_numbers,
1965 int64_t feature_group_count, int64_t batch_group_count,
1966 const PrecisionConfig* precision_config, PaddingType padding_type,
1967 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1968
1969 XlaOp DynamicConvInputGrad(
1970 XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
1971 absl::Span<const int64> window_strides,
1972 absl::Span<const std::pair<int64, int64>> padding,
1973 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1974 const ConvolutionDimensionNumbers& dimension_numbers,
1975 int64_t feature_group_count, int64_t batch_group_count,
1976 const PrecisionConfig* precision_config, PaddingType padding_type,
1977 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1978
1979 XlaOp DynamicConvKernelGrad(
1980 XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
1981 absl::Span<const std::pair<int64, int64>> padding,
1982 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1983 const ConvolutionDimensionNumbers& dimension_numbers,
1984 int64_t feature_group_count, int64_t batch_group_count,
1985 const PrecisionConfig* precision_config, PaddingType padding_type,
1986 absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1987
1988 // Enqueues an FFT instruction onto the computation, of the given type and
1989 // with the given FFT length.
1990 XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);
1991
1992 // Solves systems of linear equations with lower or upper triangular coefficient
1993 // matrices by forward- or back-substitution. Broadcasting along leading
1994 // dimensions, this routine solves for x in one of the matrix systems
1995 // `op(a) * x = b`, or `x * op(a) = b`,
1996 // for the variable `x` given `a` and `b`, where `op(a)` is either
1997 // `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
1998 //
1999 // * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
2000 // square matrices. If `lower` is true (false), then the strictly upper
2001 // (lower) triangular part of each innermost matrix in `a` is assumed to be
2002 // zero and is not accessed.
2003 // * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
2004 // tensor of shape `[..., K, M]`.
2005 // * `left_side` is a boolean, indicating whether to solve a system of the form
2006 // op(a) * x = b (true) or x * op(a) = b (false).
2007 // * `lower` is a boolean, indicating whether the argument `a` is
2008 // lower-triangular (true) or upper-triangular (false).
2009 // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
2010 // 1 and not accessed.
2011 // * `transpose_a` indicates which function `op` we use to transform the tensor
2012 // `a`: the identity function, transpose(a), or conjugate(transpose(a))
2013 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
2014 bool unit_diagonal,
2015 TriangularSolveOptions::Transpose transpose_a);
2016
2017 // Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
2018 // positive definite matrices.
2019 // `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
2020 // two minor dimensions equal.
2021 // If `lower` is true, the data from the lower triangle is used; if false, the
2022 // upper triangle is used. The input data in the other triangle of the input
2023 // does not affect the output. Returns the output in the same lower/upper
2024 // triangle. The data returned in the other output triangle is arbitrary and
2025 // implementation-defined.
2026 //
2027 // If `a` is not Hermitian positive definite, returns an array full of NaNs.
2028 XlaOp Cholesky(XlaOp a, bool lower);
2029
2030 // Enqueues an infeed instruction onto the computation, which writes data of
2031 // the given shape to the infeed buffer of the device.
2032 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
2033 const string& config = "");
2034
2035 // Variant of Infeed which takes a token-shaped operand and produces a
2036 // two-element tuple containing the data value and a token-shaped value.
2037 // Tokens are used for ordering side-effecting operations.
2038 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
2039 XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
2040 const string& config = "");
2041
2042 // Enqueues an outfeed instruction onto the computation. This instruction
2043 // generates outgoing data transfers for the given data.
2044 //
2045 // shape_with_layout communicates the laid out shape that we want to outfeed
2046 // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
2047 // will occur.
2048 void Outfeed(XlaOp operand, const Shape& shape_with_layout,
2049 const string& outfeed_config);
2050
2051 // Variant of Outfeed which takes a token-shaped operand and produces a
2052 // token-shaped value. Tokens are used for ordering side-effecting operations.
2053 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
2054 XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
2055 const Shape& shape_with_layout,
2056 const string& outfeed_config);
2057
2058 // Enqueues a call instruction onto the computation.
2059 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
2060 absl::Span<const XlaOp> operands);
2061
2062 // Enqueues a custom call instruction onto the computation. A custom call
2063 // invokes code external to XLA. The |operands| are passed to the external code,
2064 // and the external code is expected to produce a result of the given
2065 // |shape|. The exact mechanism is backend-specific. For example, in the CPU
2066 // backend, a call instruction is emitted which targets a symbol with the name
2067 // |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
2068 // but |call_target_name| should be short as it may be used in labels. |opaque|
2069 // can encode arbitrarily large amounts of information. |has_side_effect|
2070 // specifies whether the instruction can have side effects.
2071 // |output_operand_aliasing| specifies a list of output/operand buffer pairs
2072 // that alias each other, where the output buffer is represented as a
2073 // ShapeIndex, and the operand buffer is represented as the operand index and
2074 // the ShapeIndex.
2075 XlaOp CustomCall(
2076 XlaBuilder* builder, const string& call_target_name,
2077 absl::Span<const XlaOp> operands, const Shape& shape,
2078 const string& opaque = "", bool has_side_effect = false,
2079 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2080 output_operand_aliasing = {},
2081 const Literal* literal = nullptr,
2082 CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2083 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
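
// Example (sketch): invoking a hypothetical external function "my_custom_op"
// on assumed operands `arg0` and `arg1`; registering the call target with the
// backend is a separate, backend-specific step not shown here.
//
//   CustomCall(&b, "my_custom_op", /*operands=*/{arg0, arg1},
//              /*shape=*/ShapeUtil::MakeShape(F32, {16}));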
2084
2085 // Overload which constructs a custom call that applies an Xla computation.
2086 XlaOp CustomCallWithComputation(
2087 XlaBuilder* builder, const string& call_target_name,
2088 absl::Span<const XlaOp> operands, const XlaComputation& computation,
2089 const Shape& shape, const string& opaque = "", bool has_side_effect = false,
2090 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2091 output_operand_aliasing = {},
2092 const Literal* literal = nullptr,
2093     CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2094 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2095
2096 // Overload which constructs a custom call with fixed layouts. The operands will
2097 // have the layouts specified by |operand_shapes_with_layout| when provided to
2098 // external code, and the external code is expected to produce a result with the
2099 // layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
2100 // and |operand_shapes_with_layout| must have layouts.
2101 XlaOp CustomCallWithLayout(
2102 XlaBuilder* builder, const string& call_target_name,
2103 absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
2104 absl::Span<const Shape> operand_shapes_with_layout,
2105 const string& opaque = "", bool has_side_effect = false,
2106 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2107 output_operand_aliasing = {},
2108 const Literal* literal = nullptr,
2109 CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2110 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2111
2112 // Overload which annotates a custom call with the given Window and
2113 // ConvolutionDimensionNumbers. Useful for custom-calls which represent
2114 // convolutions.
2115 //
2116 // This sets the layout of its operands if operand_shapes_with_layout is
2117 // nonempty, and it sets the layout of its result if `shape` has a layout.
2118 XlaOp CustomCallWithConvDnums(
2119 XlaBuilder* builder, const string& call_target_name,
2120 absl::Span<const XlaOp> operands, const Shape& shape,
2121 absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
2122 bool has_side_effect,
2123 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2124 output_operand_aliasing,
2125 const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
2126 CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2127 CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2128
2129 // The following methods enqueue element-wise binary arithmetic operations
2130 // onto the computation. The shapes of the operands have to match unless one
2131 // of the operands is a scalar, or an explicit broadcast dimension is given
2132 // (see g3doc for more details).
2133
2134 // Enqueues a complex compose instruction onto the computation.
2135 XlaOp Complex(XlaOp real, XlaOp imag,
2136 absl::Span<const int64> broadcast_dimensions = {});
2137
2138 // Enqueues a complex conjugate instruction onto the computation.
2139 XlaOp Conj(XlaOp operand);
2140
2141 // Enqueues an add instruction onto the computation.
2142 XlaOp Add(XlaOp lhs, XlaOp rhs,
2143 absl::Span<const int64> broadcast_dimensions = {});
2144
2145 // Enqueues a subtract instruction onto the computation.
2146 XlaOp Sub(XlaOp lhs, XlaOp rhs,
2147 absl::Span<const int64> broadcast_dimensions = {});
2148
2149 // Enqueues a multiply instruction onto the computation.
2150 XlaOp Mul(XlaOp lhs, XlaOp rhs,
2151 absl::Span<const int64> broadcast_dimensions = {});
2152
2153 // Enqueues a divide instruction onto the computation.
2154 XlaOp Div(XlaOp lhs, XlaOp rhs,
2155 absl::Span<const int64> broadcast_dimensions = {});
2156
2157 // Enqueues a remainder instruction onto the computation.
2158 XlaOp Rem(XlaOp lhs, XlaOp rhs,
2159 absl::Span<const int64> broadcast_dimensions = {});
2160
2161 // Enqueues a max instruction onto the computation.
2162 XlaOp Max(XlaOp lhs, XlaOp rhs,
2163 absl::Span<const int64> broadcast_dimensions = {});
2164
2165 // Enqueues a min instruction onto the computation.
2166 XlaOp Min(XlaOp lhs, XlaOp rhs,
2167 absl::Span<const int64> broadcast_dimensions = {});
2168
2169 // Element-wise logical operators
2170 XlaOp And(XlaOp lhs, XlaOp rhs,
2171 absl::Span<const int64> broadcast_dimensions = {});
2172
2173 // Overload to call And with 3 or more operands. We need the following somewhat
2174 // convoluted overload set to disambiguate with the overload that takes the
2175 // `broadcast_dimensions` optional param.
2176 inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
2177 return And(op1, And(op2, op3));
2178 }
2179 template <typename... XlaOpTs>
2180 XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
2181 return And(op1, And(op2, And(op3, operands...)));
2182 }
2183
2184 XlaOp Or(XlaOp lhs, XlaOp rhs,
2185 absl::Span<const int64> broadcast_dimensions = {});
2186
2187 // Overload to call Or with 3 or more operands. As with `And`, we need the
2188 // following complicated overload set to handle the default arg in the `Or`
2189 // overload above.
2190 inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
2191 return Or(op1, Or(op2, op3));
2192 }
2193 template <typename... XlaOpTs>
2194 XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
2195 return Or(op1, Or(op2, Or(op3, operands...)));
2196 }
2197
2198 XlaOp Xor(XlaOp lhs, XlaOp rhs,
2199 absl::Span<const int64> broadcast_dimensions = {});
2200
2201 XlaOp Not(XlaOp operand);
2202
2203 XlaOp PopulationCount(XlaOp operand);
2204
2205 XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
2206 absl::Span<const int64> broadcast_dimensions = {});
2207 XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
2208 absl::Span<const int64> broadcast_dimensions = {});
2209 XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
2210 absl::Span<const int64> broadcast_dimensions = {});
2211
2212 // Reduces an array among the provided dimensions, given "computation" as a
2213 // reduction operator.
2214 XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
2215 absl::Span<const int64> dimensions_to_reduce);
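
// Example (sketch): summing an assumed f32[2,3] operand `x` over dimension 1
// on builder `b`, with the scalar add reduction built on a sub-builder.
//
//   XlaBuilder add_b("add");
//   auto a0 = Parameter(&add_b, 0, ShapeUtil::MakeShape(F32, {}), "a0");
//   auto a1 = Parameter(&add_b, 1, ShapeUtil::MakeShape(F32, {}), "a1");
//   Add(a0, a1);
//   XlaComputation add = add_b.Build().ValueOrDie();
//
//   // f32[2,3] -> f32[2].
//   Reduce(x, ConstantR0<float>(&b, 0.0f), add,
//          /*dimensions_to_reduce=*/{1});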
2216
2217 // Reduces several arrays simultaneously among the provided dimensions, given
2218 // "computation" as a reduction operator.
2219 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
2220 absl::Span<const XlaOp> init_values,
2221 const XlaComputation& computation,
2222 absl::Span<const int64> dimensions_to_reduce);
2223
2224 // Convenience wrapper around the above that reduces all the dimensions in the
2225 // operand shape.
2226 XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
2227 const XlaComputation& computation);
2228
2229 // Enqueues a windowed reduce instruction onto the computation.
2230 XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
2231 const XlaComputation& computation,
2232 absl::Span<const int64> window_dimensions,
2233 absl::Span<const int64> window_strides, Padding padding);
2234
2235 XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
2236 absl::Span<const XlaOp> init_values,
2237 const XlaComputation& computation,
2238 absl::Span<const int64> window_dimensions,
2239 absl::Span<const int64> window_strides, Padding padding);
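
// Example (sketch): 2x2 max pooling with stride 2 over an assumed
// f32[1,4,4,1] operand `x` on builder `b`, where `max` is a scalar max
// computation built on a sub-builder as in the Reduce example above.
//
//   ReduceWindow(x,
//                ConstantR0<float>(
//                    &b, -std::numeric_limits<float>::infinity()),
//                max, /*window_dimensions=*/{1, 2, 2, 1},
//                /*window_strides=*/{1, 2, 2, 1}, Padding::kValid);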
2240
2241 // As ReduceWindow(), but the padding is given in the format
2242 // returned by MakePadding().
2243 XlaOp ReduceWindowWithGeneralPadding(
2244 XlaOp operand, XlaOp init_value, const XlaComputation& computation,
2245 absl::Span<const int64> window_dimensions,
2246 absl::Span<const int64> window_strides,
2247 absl::Span<const int64> base_dilations,
2248 absl::Span<const int64> window_dilations,
2249 absl::Span<const std::pair<int64, int64>> padding);
2250 XlaOp ReduceWindowWithGeneralPadding(
2251 absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
2252 const XlaComputation& computation,
2253 absl::Span<const int64> window_dimensions,
2254 absl::Span<const int64> window_strides,
2255 absl::Span<const int64> base_dilations,
2256 absl::Span<const int64> window_dilations,
2257 absl::Span<const std::pair<int64, int64>> padding);
2258
2259 // Returns the sum of the operand value within each subgroup of replicas. All
2260 // replicas supply one input to the sum and all replicas receive the resulting
2261 // sum for each subgroup.
2262 XlaOp CrossReplicaSum(XlaOp operand,
2263 absl::Span<const ReplicaGroup> replica_groups = {});
2264
2265 XlaOp AllGather(
2266 XlaOp operand, int64_t all_gather_dimension, int64_t shard_count,
2267 absl::Span<const ReplicaGroup> replica_groups = {},
2268 const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
2269 const absl::optional<Layout>& layout = absl::nullopt,
2270 const absl::optional<bool> use_global_device_ids = absl::nullopt);
2271
2272 // Enqueues an operation that does an AllReduce of the operand across cores. Here
2273 // AllReduce means doing a reduction on the input operand across cores and then
2274 // broadcasting the reduction result to those cores. The reduction function is
2275 // defined by `computation`, which should be a commutative computation on
2276 // scalars, e.g., add, min, or max. The way that AllReduce is applied is
2277 // configured by:
2278 //
2279 // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
2280 // empty, all replicas belong to one group. AllReduce will be applied within
2281 // subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
2282 // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
2283 //
2284 // - `channel_id`: for AllReduce nodes from different modules, if they have the
2285 // same channel_id, they will be 'AllReduce'd together. If empty, AllReduce will
2286 // not be applied across modules.
2287 //
2288 // - `shape_with_layout`: forces the layout of the AllReduce to the given
2289 // layout. This is used to guarantee the same layout for a group of AllReduce
2290 // ops compiled separately.
2291 XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
2292 absl::Span<const ReplicaGroup> replica_groups = {},
2293 const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
2294 const absl::optional<Shape>& shape_with_layout = absl::nullopt);
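
// Example (sketch): a global sum across all replicas of an assumed operand
// `x`, where `add` is a scalar add computation as in the Reduce example.
//
//   AllReduce(x, add);  // Empty replica_groups: one group of all replicas.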
2295
2296 XlaOp ReduceScatter(
2297 XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
2298 int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups = {},
2299 const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
2300 const absl::optional<Layout>& layout = absl::nullopt,
2301 const absl::optional<bool> use_global_device_ids = absl::nullopt);
2302
2303 // Enqueues an operation that does an AllToAll of the operand across cores.
2304 // An optional `layout` can be specified to force the layout of the instruction.
2305 // This is used to guarantee the same layout for a group of AllToAll ops
2306 // compiled separately.
2307 XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension,
2308 int64_t split_count,
2309 absl::Span<const ReplicaGroup> replica_groups = {},
2310 const absl::optional<Layout>& layout = absl::nullopt);
2311
2312 XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
2313 int64_t concat_dimension, int64_t split_count,
2314 absl::Span<const ReplicaGroup> replica_groups = {},
2315 const absl::optional<Layout>& layout = absl::nullopt);
2316
2317 // Enqueues a collective operation that sends and receives data across replicas.
2318 //
2319 // - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
2320 // pairs. For each pair, the operand is sent from the source replica to the
2321 // target replica. Note that 1) no two pairs may share the same target replica
2322 // id, and no two pairs may share the same source replica id; 2) if a replica id
2323 // is not a target in any pair, then the output on that replica is a tensor
2324 // consisting of 0(s) with the same shape as the input.
2325 XlaOp CollectivePermute(
2326 XlaOp operand,
2327 const std::vector<std::pair<int64, int64>>& source_target_pairs);
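
// Example (sketch): with two replicas, exchanging an assumed operand `x`
// between them.
//
//   CollectivePermute(x, /*source_target_pairs=*/{{0, 1}, {1, 0}});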
2328
2329 // Enqueues an operation that returns the replica ID.
2330 XlaOp ReplicaId(XlaBuilder* builder);
2331
2332 // Enqueues an operation that scatters the `source` array to the selected
2333 // indices of each window.
2334 XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
2335 absl::Span<const int64> window_dimensions,
2336 absl::Span<const int64> window_strides, Padding padding,
2337 XlaOp source, XlaOp init_value,
2338 const XlaComputation& scatter);
2339
2340 // As SelectAndScatter(), but the padding is given in the format
2341 // returned by MakePadding().
2342 XlaOp SelectAndScatterWithGeneralPadding(
2343 XlaOp operand, const XlaComputation& select,
2344 absl::Span<const int64> window_dimensions,
2345 absl::Span<const int64> window_strides,
2346 absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
2347 XlaOp init_value, const XlaComputation& scatter);
2348
2349 // Enqueues an abs instruction onto the computation.
2350 XlaOp Abs(XlaOp operand);
2351
2352 // Enqueues an atan2 instruction onto the computation.
2353 XlaOp Atan2(XlaOp y, XlaOp x,
2354 absl::Span<const int64> broadcast_dimensions = {});
2355
2356 // Enqueues an exp instruction onto the computation.
2357 XlaOp Exp(XlaOp operand);
2358
2359 // Enqueues an expm1 instruction onto the computation.
2360 XlaOp Expm1(XlaOp operand);
2361
2362 // Enqueues a floor instruction onto the computation.
2363 XlaOp Floor(XlaOp operand);
2364
2365 // Enqueues a ceil instruction onto the computation.
2366 XlaOp Ceil(XlaOp operand);
2367
2368 // Enqueues a round instruction onto the computation, rounding to the nearest
2369 // integer, with half-way cases rounding away from zero.
2370 XlaOp Round(XlaOp operand);
2371
2372 // Enqueues a log instruction (natural logarithm) onto the computation.
2373 XlaOp Log(XlaOp operand);
2374
2375 // Enqueues a log1p instruction (log(x+1)) onto the computation.
2376 XlaOp Log1p(XlaOp operand);
2377
2378 // Enqueues a logistic instruction onto the computation.
2379 XlaOp Logistic(XlaOp operand);
2380
2381 // Enqueues a sign instruction onto the computation.
2382 XlaOp Sign(XlaOp operand);
2383
2384 // Enqueues a count leading zeros instruction onto the computation.
2385 XlaOp Clz(XlaOp operand);
2386
2387 // Enqueues a cosine instruction onto the computation.
2388 XlaOp Cos(XlaOp operand);
2389
2390 // Enqueues a sine instruction onto the computation.
2391 XlaOp Sin(XlaOp operand);
2392
2393 // Enqueues a tanh instruction onto the computation.
2394 XlaOp Tanh(XlaOp operand);
2395
2396 // Enqueues a real-part instruction onto the computation.
2397 XlaOp Real(XlaOp operand);
2398
2399 // Enqueues an imaginary-part instruction onto the computation.
2400 XlaOp Imag(XlaOp operand);
2401
2402 // Enqueues a sqrt computation onto the computation.
2403 XlaOp Sqrt(XlaOp operand);
2404
2405 // Enqueues a cbrt computation onto the computation.
2406 XlaOp Cbrt(XlaOp operand);
2407
2408 // Enqueues a rsqrt computation onto the computation.
2409 XlaOp Rsqrt(XlaOp operand);
2410
2411 // Enqueues a lhs^rhs computation onto the computation.
2412 XlaOp Pow(XlaOp lhs, XlaOp rhs,
2413 absl::Span<const int64> broadcast_dimensions = {});
2414
2415 // Enqueues an operator that tests if the operand's values are finite, i.e., not
2416 // +/-Inf or NaN. Returns an array of booleans with the same shape where
2417 // entries are true iff the corresponding entry was not infinite or NaN.
2418 //
2419 // Defined only for real-valued (i.e. not complex) floating-point types; raises
2420 // an error for other types.
2421 //
2422 // See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
2423 XlaOp IsFinite(XlaOp operand);
2424
2425 // Enqueues an iota operation onto the computation.
2426 XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension);
2427
2428 // Enqueues a rank-1 iota operation onto the computation.
2429 XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);
2430
2431 // Enqueues a convert instruction onto the computation that changes the
2432 // element type of the operand array to new_element_type.
2433 XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);
2434
2435 // Enqueues a no-op instruction onto the computation that changes
2436 // the element type of the operand array to new_element_type. The
2437 // bit-widths of the source and destination element types must be
2438 // identical.
2439 XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
2440
2441 // Enqueues a negate instruction onto the computation.
2442 XlaOp Neg(XlaOp operand);
2443
2444 // Enqueues a transpose instruction onto the computation.
2445 XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
2446
2447 // Enqueues a reverse instruction onto the computation. The order of the
2448 // elements in the given dimensions is reversed (i.e., the element at index i
2449 // is moved to index dimension_size - 1 - i).
2450 XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting is used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted
//   array. The resulting sorting order has the property that for all index
//   positions i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false
//   or comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1
//   will independently sort each row. If no dimension number is provided, the
//   last dimension is chosen by default. For the dimension which is sorted,
//   the same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types
//   of the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order
//   (along the provided dimension, as above). The same permutation as implied
//   by the comparison computation is applied to all operand tensors. When
//   comparing two index positions, 'comparator' is called with 2 * n scalar
//   parameters, where parameters 2 * i and 2 * i + 1 correspond to the value
//   of operand i at the two index positions.
// Default comparator computations can be found in lib/comparators.h.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension = -1, bool is_stable = false);
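
// Illustrative sketch of a multi-operand sort (assuming a builder `b` and
// the comparator helpers declared in lib/comparators.h):
//   XlaOp keys = ConstantR1<float>(&b, {3.0f, 1.0f, 2.0f});
//   XlaOp vals = ConstantR1<int32>(&b, {0, 1, 2});
//   XlaOp sorted = Sort({keys, vals},
//                       CreateScalarLtComputation({F32, S32}, &b),
//                       /*dimension=*/0);
//   // Result tuple: keys {1.0f, 2.0f, 3.0f}, vals {1, 2, 0}.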

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation,
          absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues an N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
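
// Illustrative example (assuming a builder `b`): draws three f32 samples
// from the semi-open interval [0, 1).
//   XlaOp samples = RngUniform(ConstantR0<float>(&b, 0.0f),
//                              ConstantR0<float>(&b, 1.0f),
//                              ShapeUtil::MakeShape(F32, {3}));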

// Enqueues a B(initial_state) random bit generation instruction onto the
// computation. Returns the new key and random bits with the specified shape.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                      const Shape& shape);

// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            XlaOp init);
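
// Illustrative sketch of a counting loop (assuming a builder `b`; the
// condition and body are built with sub-builders, and Build() is assumed to
// succeed):
//   Shape s32 = ShapeUtil::MakeShape(S32, {});
//
//   auto cond_b = b.CreateSubBuilder("cond");
//   Lt(Parameter(cond_b.get(), 0, s32, "x"),
//      ConstantR0<int32>(cond_b.get(), 10));
//   XlaComputation cond = cond_b->Build().ConsumeValueOrDie();
//
//   auto body_b = b.CreateSubBuilder("body");
//   Add(Parameter(body_b.get(), 0, s32, "x"),
//       ConstantR0<int32>(body_b.get(), 1));
//   XlaComputation body = body_b->Build().ConsumeValueOrDie();
//
//   // Iterates until the loop-carried value reaches 10.
//   XlaOp result = While(cond, body, ConstantR0<int32>(&b, 0));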

// Enqueues a conditional node onto the computation.
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                  const XlaComputation& true_computation, XlaOp false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. The N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed; an out-of-range branch_index selects the last (N-1'th)
// branch_computation as the default.
XlaOp Conditional(XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);
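
// Illustrative sketch (assuming a builder `b`, three scalar-to-scalar
// XlaComputations c0/c1/c2, and matching operands x0/x1/x2 built elsewhere):
//   const XlaComputation* branches[] = {&c0, &c1, &c2};
//   XlaOp operands[] = {x0, x1, x2};
//   // Selects c1 applied to x1; an out-of-range index would select c2.
//   XlaOp result = Conditional(ConstantR0<int32>(&b, 1), branches, operands);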

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(XlaOp input, XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes,
             bool indices_are_sorted = false);

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to a Recv
// instruction in a different computation that shares the same channel handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred;
// its shape must be compatible with the shape of the operand. The operand
// must be array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands
// must be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
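
// Illustrative sketch of token threading (assuming a builder `b`, channel
// handles h1/h2, and array-shaped values d1/d2 built elsewhere):
//   XlaOp token = CreateToken(&b);
//   XlaOp t1 = SendWithToken(d1, token, h1);
//   XlaOp t2 = SendWithToken(d2, token, h2);
//   XlaOp joined = AfterAll(&b, {t1, t2});  // orders both sends before uses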

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64_t feature_index);
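
// Illustrative sketch of unpacking the result tuple (assuming an NHWC
// operand, so the feature dimension is index 3):
//   XlaOp bn = BatchNormTraining(operand, scale, offset, /*epsilon=*/1e-5f,
//                                /*feature_index=*/3);
//   XlaOp normalized = GetTupleElement(bn, 0);
//   XlaOp batch_mean = GetTupleElement(bn, 1);
//   XlaOp batch_var = GetTupleElement(bn, 2);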

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64_t feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
// - grad_operand: Gradient with respect to input `operand`
// - grad_offset: Gradient with respect to input `offset`
// - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64_t feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

// Sets the size of the given dimension of the operand to `val`, marking that
// dimension as dynamic. The operand must be array shaped.
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

// Returns the same op but with the dynamic dimension removed.
XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);
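
// Illustrative sketch (assuming an array-shaped `x` whose leading dimension
// carries a runtime size):
//   XlaOp n = GetDimensionSize(x, 0);        // S32 scalar runtime size
//   XlaOp y = SetDimensionSize(x, n, 0);     // marks dimension 0 as dynamic
//   XlaOp z = RemoveDynamicDimension(y, 0);  // back to a static dimension 0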

// Implementation details below this point.
//

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  BorrowingLiteral literal(
      reinterpret_cast<const char*>(values.begin()),
      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
                           {static_cast<int64>(values.size())}));
  return ConstantLiteral(builder, literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}
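
// Illustrative usage of the constant helpers (assuming a builder `b`):
//   XlaOp s = ConstantR0<float>(&b, 42.0f);               // scalar
//   XlaOp v = ConstantR1<int32>(&b, {1, 2, 3});           // rank 1
//   XlaOp f = ConstantR1<float>(&b, /*length=*/4, 0.0f);  // {0, 0, 0, 0}
//   XlaOp m = ConstantR2<float>(&b, {{1.0f, 2.0f}, {3.0f, 4.0f}});  // rank 2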

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_