/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <cstdint>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stacktrace.h"

namespace xla {

class XlaBuilder;
class XlaOp;
class HloInstruction;

namespace internal {

struct XlaBuilderFriend {
  static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand,
                                  XlaOp token, const Shape& shape);

  static XlaOp BuildFusion(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           absl::string_view fusion_kind,
                           const XlaComputation& fused_computation);

  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
                            const Shape& shape);

  static XlaOp BuildPartitionId(XlaBuilder* builder, const Shape& shape);

  static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand,
                           const OpSharding entry, const OpSharding exit,
                           const Shape& shape);

  static XlaOp BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta,
                                         const Shape& shape);

  static HloInstructionProto* GetInstruction(XlaOp op);
  static HloInstructionProto* GetInstructionByHandle(XlaBuilder* builder,
                                                     int64_t handle);
};

}  // namespace internal

// This represents an instruction that has been enqueued using the XlaBuilder.
// This is passed to subsequent computations that depend upon the
// instruction as an operand.
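//
// A minimal sketch (illustrative only) of the states an XlaOp can be in:
//
//   XlaOp op;               // Default-constructed.
//   op.IsUninitialized();   // true: no builder is attached.
//   op.valid();             // false: the internal handle is -1.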
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar().  Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point possibly
  // deep in the callstack when we finally dereference `this`.  The precondition
  // lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(XlaOp rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64_t handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64_t handle() const { return handle_; }

  friend class XlaBuilder;
  friend class ValueInference;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

  // < 0 means "invalid handle".
  int64_t handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};

// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(XlaOp x);
XlaOp operator+(XlaOp x, XlaOp y);
XlaOp operator-(XlaOp x, XlaOp y);
XlaOp operator*(XlaOp x, XlaOp y);
XlaOp operator/(XlaOp x, XlaOp y);
XlaOp operator%(XlaOp x, XlaOp y);

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(XlaOp x);
XlaOp operator&(XlaOp x, XlaOp y);
XlaOp operator|(XlaOp x, XlaOp y);
XlaOp operator^(XlaOp x, XlaOp y);
XlaOp operator<<(XlaOp x, XlaOp y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
// a right logical shift.
XlaOp operator>>(XlaOp x, XlaOp y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
// semantics might be surprising since their result types are usually 'bool'.
// Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.
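//
// Example (illustrative): with these overloads, expressions compose directly
// on XlaOp values enqueued on the same builder:
//
//   XlaOp z = (x + y) * x - y;  // Enqueues Mul, Add, and Sub instructions.
//   XlaOp s = x >> y;           // Arithmetic shift iff x's type is signed.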

// A convenient interface for building up computations.
//
// Thread-compatible.
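//
// Example (a minimal sketch; Parameter and ConstantR0 are assumed to be the
// free-function op builders declared with this class, and error handling is
// omitted):
//
//   XlaBuilder b("axpy");
//   XlaOp a = ConstantR0<float>(&b, 2.0f);
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
//   a * x;  // The last enqueued op becomes the root.
//   StatusOr<XlaComputation> computation = b.Build();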
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  explicit XlaBuilder(const std::string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  virtual ~XlaBuilder();

  // Returns the computation name.
  const std::string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }

  // Swaps the passed op metadata with the one currently set.
  //
  // Returns the old op metadata.
  OpMetadata SwapOpMetadata(OpMetadata metadata) {
    OpMetadata old_metadata = std::move(metadata_);
    metadata_ = std::move(metadata);
    return old_metadata;
  }
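
  // Example (illustrative): temporarily override the metadata around a region
  // of ops, then restore the previous value:
  //
  //   OpMetadata m;
  //   m.set_op_name("fused_region");
  //   OpMetadata saved = b.SwapOpMetadata(std::move(m));
  //   // ... enqueue ops carrying the new metadata ...
  //   b.SwapOpMetadata(std::move(saved));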

  // Similar to SetOpMetadata, but only sets the metadata for the next op.
  void SetOneShotOpMetadata(OpMetadata metadata) {
    one_shot_metadata_ = std::move(metadata);
  }

  // Clears the OpMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }

  // Sets an OpSharding that will be attached to all instructions until cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Sets the FrontendAttributes that will be added to all instructions until
  // cleared.
  //
  // FrontendAttributes are often applied to a series of XLA HLO instructions.
  // As a result they are set on the computation builder, and all the
  // instructions generated via the computation builder will have the same
  // frontend attributes attached to them.
  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
    frontend_attributes_ = frontend_attributes;
  }

  // Swaps the passed FrontendAttributes with the ones currently set.
  //
  // Returns the old attributes.
  FrontendAttributes SwapFrontendAttributes(
      const FrontendAttributes& frontend_attributes) {
    FrontendAttributes old_attributes = std::move(frontend_attributes_);
    frontend_attributes_ = frontend_attributes;
    return old_attributes;
  }

  // Returns the FrontendAttributes that will be attached to all instructions.
  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Clears all the frontend attributes.
  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }

  // Clears the sharding. Ops will be sharded according to the default placement
  // policy.
  void ClearSharding() { sharding_ = std::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const std::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing the error in a deferred fashion when
  // Build() is called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64_t kConvBatchDimension = 0;
  static constexpr int64_t kConvFeatureDimension = 1;
  static constexpr int64_t kConvFirstSpatialDimension = 2;
  static constexpr int64_t kConvSecondSpatialDimension = 3;
  static constexpr int64_t kConvKernelOutputDimension = 0;
  static constexpr int64_t kConvKernelInputDimension = 1;
  static constexpr int64_t kConvKernelFirstSpatialDimension = 2;
  static constexpr int64_t kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);
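
  // Example (illustrative): the defaults correspond to an NCHW input layout
  // and an OIHW kernel layout:
  //
  //   ConvolutionDimensionNumbers dnums =
  //       XlaBuilder::CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/2);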

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(
      const std::string& computation_name);
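
  // Example (a minimal sketch): build a comparator computation for use by this
  // builder, e.g. as the argument of a Sort op. `scalar_shape` is an
  // illustrative placeholder, and Lt/Parameter are assumed to be the
  // free-function op builders declared with this class:
  //
  //   std::unique_ptr<XlaBuilder> sub = b.CreateSubBuilder("less_than");
  //   Lt(Parameter(sub.get(), 0, scalar_shape, "a"),
  //      Parameter(sub.get(), 1, scalar_shape, "b"));
  //   XlaComputation comparator = sub->BuildAndNoteError();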

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = false);

  // Builds the computation with the requested operations, or notes an error in
  // the parent XlaBuilder and returns an empty computation if building failed.
  // This function is intended to be used where the returned XlaComputation is
  // only used by the parent XlaBuilder, and hence further operations on the
  // returned XlaComputation will simply be errored out if an error occurred
  // while building this computation. If the built computation is to be used by
  // an XlaBuilder other than the parent XlaBuilder then Build() should be used
  // instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that roots on the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(XlaOp op) const;

  // Returns the shape of the given op.
  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using the
  // given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //    Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value can
  // be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to builder and returns an
  // invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
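
  // Example (a minimal sketch of the intended usage pattern; BuildSomeOp is a
  // hypothetical helper):
  //
  //   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  //     TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(operand));
  //     return BuildSomeOp(*shape, operand);
  //   });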

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(XlaOp operand) const;
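
  // Example (illustrative; ConstantR0 is assumed to be a free-function op
  // builder declared with this class):
  //
  //   XlaOp one = ConstantR0<float>(&b, 1.0f);
  //   b.IsConstant(one + one);  // Holds true.
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   b.IsConstant(p + one);    // Holds false: depends on a parameter.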

  // Sets up binding which indicates that the `target_dim_num` in the subshape
  // `target_param_index` of parameter `target_param_num` is a dynamic dimension
  // and its real dynamic size is represented by `dynamic_param_index` in
  // parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.")
  Status SetDynamicBinding(int64_t dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64_t target_param_num,
                           ShapeIndex target_param_index,
                           int64_t target_dim_num);
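
  // Example (illustrative, using the recommended replacement; this assumes the
  // free-function form of SetDimensionSize declared with the other op
  // builders):
  //
  //   XlaOp sized = SetDimensionSize(operand, dynamic_size, /*dimension=*/0);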

  // Adds a new input/output alias. Because the input/output shape information
  // is not available until the computation is built, any error in the
  // arguments of this API will be detected only at computation Build() time.
  //
  // Note: Except when 'must-alias' is true, the alias is assumed to be
  // 'may-alias', and only a buffer donated at runtime will be aliased with the
  // output. If a buffer is not donated at runtime, a copy will be inserted by
  // XLA to prevent buffer clobbering.
  void SetUpAlias(const ShapeIndex& output_index, int64_t param_number,
                  const ShapeIndex& param_index,
                  HloInputOutputAliasConfig::AliasKind kind =
                      HloInputOutputAliasConfig::AliasKind::kMayAlias) {
    input_output_aliases_.push_back(
        {output_index, param_number, param_index, kind});
  }
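
  // Example (a minimal sketch): mark element {0} of a tuple-shaped result as
  // may-aliasing the buffer of (non-tuple) parameter 0:
  //
  //   b.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
  //                /*param_index=*/{});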

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64_t param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
    // Specifies whether the alias is a must-alias or a may-alias.
    HloInputOutputAliasConfig::AliasKind kind;
  };

  // Looks up the HloInstruction and sets the frontend attribute "attribute" to
  // "value".
  //
  // If the attribute already existed, then its value is updated.
  //
  // Note: the attribute is only added to the HloInstruction, not to the
  // builder.
  Status SetInstructionFrontendAttribute(XlaOp op, std::string attribute,
                                         std::string value);

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // Converts the op to a string for ease of debugging.
  std::string OpToString(XlaOp op) const;

 private:
  void ToStringHelper(std::string* out, int ident, int64_t op_handle) const;

  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64_t root_id,
                                 bool remove_dynamic_dimensions);

  // Descriptions for the methods below can be found in the corresponding
  // public functions section in this file.

  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const std::string& name,
                  const std::vector<bool>& replicated_at_leaf_buffers);
  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const std::string& name) {
    std::vector<bool> empty_bools;
    return Parameter(parameter_number, shape, name, empty_bools);
  }

  virtual XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(XlaOp operand, absl::Span<const int64_t> broadcast_sizes);

  XlaOp BroadcastInDim(XlaOp operand,
                       const absl::Span<const int64_t> out_dim_size,
                       const absl::Span<const int64_t> broadcast_dimensions);

  XlaOp Pad(XlaOp operand, XlaOp padding_value,
            const PaddingConfig& padding_config);
  XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                 int64_t pad_lo, int64_t pad_hi);

  virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                                      XlaOp padding_value,
                                      const PaddingConfig& padding_config);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
                absl::Span<const int64_t> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64_t inferred_dimension = -1);

  XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                       absl::Span<const int64_t> new_size_bounds,
                       const std::vector<bool>& dims_are_dynamic);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64_t> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
              absl::Span<const int64_t> limit_indices,
              absl::Span<const int64_t> strides);
  virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64_t> start_indices,
                                        absl::Span<const int64_t> limit_indices,
                                        absl::Span<const int64_t> strides);
  virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                           int64_t limit_index, int64_t stride, int64_t dimno);

  XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64_t> slice_sizes);
  virtual StatusOr<XlaOp> DynamicSliceInternal(
      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
      absl::Span<const int64_t> slice_sizes);

  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                           absl::Span<const XlaOp> start_indices);
  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
      const Shape& shape, XlaOp operand, XlaOp update,
      absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64_t dimension);
  virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
                                              absl::Span<const XlaOp> operands,
                                              int64_t dimension);

  XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);
  virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
                                        absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
                                                  XlaOp tuple_data,
                                                  int64_t index);

  XlaOp Dot(XlaOp lhs, XlaOp rhs,
            const PrecisionConfig* precision_config = nullptr,
            std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp DotGeneral(
      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp Conv(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      Padding padding, int64_t feature_group_count = 1,
      int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt,
      std::optional<std::vector<bool>> window_reversal = std::nullopt);

  XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  StatusOr<HloInstructionProto> DynamicConvInstruction(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type = std::nullopt);

  virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config);

  XlaOp Fft(XlaOp operand, FftType fft_type,
            absl::Span<const int64_t> fft_length);
  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                                      FftType fft_type,
                                      absl::Span<const int64_t> fft_length);

  virtual StatusOr<XlaOp> TriangularSolveInternal(
      const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options);

  virtual StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
                                           bool lower);

  XlaOp Infeed(const Shape& shape, const std::string& config = "");
  XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                        const std::string& config);
  virtual StatusOr<XlaOp> InfeedWithTokenInternal(
      const Shape& infeed_instruction_shape, XlaOp token,
      const std::string& config);

  void Outfeed(XlaOp operand, const Shape& shape_with_layout,
               const std::string& outfeed_config);
  XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                         const Shape& shape_with_layout,
                         const std::string& outfeed_config);
  virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
      const std::string& outfeed_config);
  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const std::string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const std::string& opaque,
      std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, std::optional<Window> window,
      std::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  // Internal version of CustomCall without computation that doesn't do
  // op-specific error handling and expects arguments to be legal. The
  // CustomCall method above calls this method after error handling.
  virtual StatusOr<XlaOp> CustomCallInternal(
      const std::string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation* computation, const Shape& shape_with_layout,
      const std::string& opaque,
      std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, std::optional<Window> window,
      std::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  // TODO(b/239474321) Remove this overload as it has simply led to code
  // duplication.
  XlaOp CustomCall(
      const std::string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation& computation, const Shape& shape_with_layout,
      const std::string& opaque,
      std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);

  XlaOp OptimizationBarrier(XlaOp operand);

  XlaOp Reduce(XlaOp operand, XlaOp init_value,
               const XlaComputation& computation,
               absl::Span<const int64_t> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64_t> dimensions_to_reduce);

  virtual StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64_t> dimensions_to_reduce);

  XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64_t> window_dimensions,
                     absl::Span<const int64_t> window_strides, Padding padding);

  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                     absl::Span<const XlaOp> init_values,
                     const XlaComputation& computation,
                     absl::Span<const int64_t> window_dimensions,
                     absl::Span<const int64_t> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const int64_t> base_dilations,
      absl::Span<const int64_t> window_dilations,
      absl::Span<const std::pair<int64_t, int64_t>> padding);
  StatusOr<HloInstructionProto> ReduceWindowInternal(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const int64_t> base_dilations,
      absl::Span<const int64_t> window_dilations,
      absl::Span<const std::pair<int64_t, int64_t>> padding);
  virtual StatusOr<XlaOp> ReduceWindowInternal(
      const Shape& shape, XlaOp operand, XlaOp init_value,
      const XlaComputation& computation, Window window);
  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp AllGather(
      XlaOp operand, int64_t all_gather_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const std::optional<ChannelHandle>& channel_id = std::nullopt,
      const std::optional<Layout>& layout = std::nullopt,
      const std::optional<bool> use_global_device_ids = std::nullopt);

  XlaOp AllReduce(
      XlaOp operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const std::optional<ChannelHandle>& channel_id = std::nullopt,
      const std::optional<Shape>& shape_with_layout = std::nullopt,
      const std::optional<bool> use_global_device_ids = std::nullopt);

  XlaOp ReduceScatter(
      XlaOp operand, const XlaComputation& computation,
      int64_t scatter_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const std::optional<ChannelHandle>& channel_id = std::nullopt,
      const std::optional<Layout>& layout = std::nullopt,
      const std::optional<bool> use_global_device_ids = std::nullopt);

  XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
                 int64_t concat_dimension, int64_t split_count,
                 absl::Span<const ReplicaGroup> replica_groups,
                 const std::optional<Layout>& layout = std::nullopt);

  XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const std::optional<Layout>& layout);

  XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
                      int64_t concat_dimension, int64_t split_count,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const std::optional<Layout>& layout);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                         absl::Span<const int64_t> window_dimensions,
                         absl::Span<const int64_t> window_strides,
                         Padding padding, XlaOp source, XlaOp init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  StatusOr<HloInstructionProto> SelectAndScatterInternal(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64_t> window_dimensions,
      absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension);

  XlaOp Iota(PrimitiveType type, int64_t size);

  XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
                                                     XlaOp operand);

  XlaOp Transpose(XlaOp operand, absl::Span<const int64_t> permutation);
  virtual StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand, absl::Span<const int64_t> permutation);

  XlaOp Rev(XlaOp operand, absl::Span<const int64_t> dimensions);
  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                                      absl::Span<const int64_t> dimensions);

  XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
             int64_t dimension = -1, bool is_stable = false);
  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
                                       absl::Span<const XlaOp> operands,
                                       const XlaComputation& comparator,
                                       int64_t dimension, bool is_stable);

  XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64_t> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

  XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

  XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                        const Shape& shape);
  // Internal variant for the op with the full result shape containing both data
  // and state shape as a tuple.
  virtual StatusOr<XlaOp> RngBitGeneratorInternal(
      const Shape& full_result_shape, RandomAlgorithm algorithm,
      XlaOp initial_state);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              XlaOp init);
  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                        const XlaComputation& condition,
                                        const XlaComputation& body, XlaOp init);

  XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                    const XlaComputation& true_computation, XlaOp false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(XlaOp branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                        const int mantissa_bits);
  virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
                                                  XlaOp operand,
                                                  const int exponent_bits,
                                                  const int mantissa_bits);

  XlaOp Gather(XlaOp input, XlaOp start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64_t> slice_sizes,
               bool indices_are_sorted = false);

  virtual StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64_t> slice_sizes, bool indices_are_sorted);

  XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);
  XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
                absl::Span<const XlaOp> updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);

  virtual StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
      absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
      bool unique_indices);

  void Send(XlaOp operand, const ChannelHandle& handle);
  XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

  XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                   const ChannelHandle& handle);

  XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                     const ChannelHandle& handle);

  virtual XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                          float epsilon, int64_t feature_index);

  XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                           XlaOp variance, float epsilon,
                           int64_t feature_index);

  XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                      XlaOp batch_var, XlaOp grad_output, float epsilon,
                      int64_t feature_index);

  XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

  XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

  virtual StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape,
                                                   XlaOp operand, XlaOp val,
                                                   int64_t dimension);

  XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);

  virtual StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                         HloOpcode opcode,
                                         absl::Span<const XlaOp> operands);
  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                 HloOpcode opcode) {
    return AddInstruction(std::move(instr), opcode, /*operands=*/{});
  }

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64_t handle) const;
  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(
      int64_t handle);

  // Internal helper method that does the building for an arbitrary unary op.
  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction is
  // only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                 absl::Span<const int64_t> broadcast_dimensions,
                 std::optional<ComparisonDirection> direction = std::nullopt,
                 std::optional<Comparison::Type> type = std::nullopt);

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction);

  // Internal helper method for binary op compare without broadcast dimensions.
  virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                  ComparisonDirection direction,
                                  Comparison::Type type);

  // Internal helper method that does the building for an arbitrary binary op
  // with same-rank operands that doesn't broadcast.
  virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                    XlaOp lhs, XlaOp rhs);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  virtual StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                        absl::Span<const XlaOp> parameters,
                                        const Shape& shape);

  virtual StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64_t> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       XlaOp operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  virtual StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                          int64_t inferred_dimension);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64_t root_id) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false iff
  // a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64_t op_handle, int depth,
                         absl::flat_hash_set<int64_t>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  int64_t GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored within
  // the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  std::string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64_t next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  // Use a deque so pointers into this are stable, for example the return
  // value of LookUpInstructionByHandle().
  std::deque<HloInstructionProto> instructions_;
  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
  std::vector<std::unique_ptr<Shape>> instruction_shapes_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias() API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where the
  // instruction is held.
  absl::flat_hash_map<int64_t, int64_t> handle_to_index_;

  // Track imported instructions by their computation id and the position in
  // their computation's instruction list.
  struct ImportedInstruction {
    int64_t computation_id;
    int64_t instruction_index;
  };

  absl::flat_hash_map<int64_t, ImportedInstruction> handle_to_imported_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64_t, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64_t> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // A temporary metadata that will only be applied to the next op created.
  std::optional<OpMetadata> one_shot_metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
  std::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  FrontendAttributes frontend_attributes_;

  friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                         const Shape& shape, const std::string& name,
                         const std::vector<bool>& replicated_at_leaf_buffers);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(XlaOp operand,
                         absl::Span<const int64_t> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      XlaOp operand, const absl::Span<const int64_t> out_dim_size,
      const absl::Span<const int64_t> broadcast_dimensions);

  friend XlaOp Copy(XlaOp operand);

  friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                        int64_t pad_lo, int64_t pad_hi);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
                       absl::Span<const int64_t> new_sizes);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes);

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                              absl::Span<const int64_t> new_size_bounds,
                              const std::vector<bool>& dims_are_dynamic);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64_t> new_sizes,
                                            int64_t inferred_dimension);

  friend XlaOp Collapse(XlaOp operand, absl::Span<const int64_t> dimensions);

  friend XlaOp Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
                     absl::Span<const int64_t> limit_indices,
                     absl::Span<const int64_t> strides);

  friend XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                          int64_t limit_index, int64_t stride, int64_t dimno);

  friend XlaOp DynamicSlice(XlaOp operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64_t> slice_sizes);

  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64_t dimension);

  friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64_t> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64_t> broadcast_dimensions,
                       ComparisonDirection direction,
                       Comparison::Type compare_type);
  friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
                   const PrecisionConfig* precision_config,
                   std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config,
                          std::optional<PrimitiveType> preferred_element_type);
  virtual StatusOr<XlaOp> DotGeneralInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs,
      const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config);
  friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64_t> window_strides, Padding padding,
                    int64_t feature_group_count, int64_t batch_group_count,
                    const PrecisionConfig* precision_config,
                    std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
      absl::Span<const std::pair<int64_t, int64_t>> padding,
      absl::Span<const int64_t> lhs_dilation,
      absl::Span<const int64_t> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      std::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvKernelGrad(
1206       XlaOp activations, XlaOp gradients,
1207       absl::Span<const int64_t> window_strides,
1208       absl::Span<const std::pair<int64_t, int64_t>> padding,
1209       absl::Span<const int64_t> lhs_dilation,
1210       absl::Span<const int64_t> rhs_dilation,
1211       const ConvolutionDimensionNumbers& dimension_numbers,
1212       int64_t feature_group_count, int64_t batch_group_count,
1213       const PrecisionConfig* precision_config, PaddingType padding_type,
1214       std::optional<PrimitiveType> preferred_element_type);
1215   friend XlaOp DynamicConvInputGrad(
1216       XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
1217       absl::Span<const int64_t> window_strides,
1218       absl::Span<const std::pair<int64_t, int64_t>> padding,
1219       absl::Span<const int64_t> lhs_dilation,
1220       absl::Span<const int64_t> rhs_dilation,
1221       const ConvolutionDimensionNumbers& dimension_numbers,
1222       int64_t feature_group_count, int64_t batch_group_count,
1223       const PrecisionConfig* precision_config, PaddingType padding_type,
1224       std::optional<PrimitiveType> preferred_element_type);
1225 
1226   friend XlaOp ConvKernelGrad(
1227       XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1228       absl::Span<const std::pair<int64_t, int64_t>> padding,
1229       absl::Span<const int64_t> lhs_dilation,
1230       absl::Span<const int64_t> rhs_dilation,
1231       const ConvolutionDimensionNumbers& dimension_numbers,
1232       int64_t feature_group_count, int64_t batch_group_count,
1233       const PrecisionConfig* precision_config,
1234       std::optional<PrimitiveType> preferred_element_type);
1235 
1236   friend XlaOp ConvGeneralDilated(
1237       XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
1238       absl::Span<const std::pair<int64_t, int64_t>> padding,
1239       absl::Span<const int64_t> lhs_dilation,
1240       absl::Span<const int64_t> rhs_dilation,
1241       const ConvolutionDimensionNumbers& dimension_numbers,
1242       int64_t feature_group_count, int64_t batch_group_count,
1243       const PrecisionConfig* precision_config,
1244       std::optional<PrimitiveType> preferred_element_type,
1245       std::optional<std::vector<bool>> window_reversal);
1246 
1247   friend XlaOp Fft(XlaOp operand, FftType fft_type,
1248                    absl::Span<const int64_t> fft_length);
1249   friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
1250                                bool unit_diagonal,
1251                                TriangularSolveOptions::Transpose transpose_a);
1252   friend XlaOp Cholesky(XlaOp a, bool lower);
1253   friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
1254                       const std::string& config);
1255   friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
1256                       const std::string& outfeed_config);
1257   friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
1258                     absl::Span<const XlaOp> operands);
1259   friend XlaOp CustomCall(
1260       XlaBuilder* builder, const std::string& call_target_name,
1261       absl::Span<const XlaOp> operands, const Shape& shape,
1262       const std::string& opaque, bool has_side_effect,
1263       absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1264           output_operand_aliasing,
1265       const Literal* literal, CustomCallSchedule schedule,
1266       CustomCallApiVersion api_version);
1267   friend XlaOp CustomCallWithComputation(
1268       XlaBuilder* builder, const std::string& call_target_name,
1269       absl::Span<const XlaOp> operands, const XlaComputation& computation,
1270       const Shape& shape, const std::string& opaque, bool has_side_effect,
1271       absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1272           output_operand_aliasing,
1273       const Literal* literal, CustomCallSchedule schedule,
1274       CustomCallApiVersion api_version);
1275   friend XlaOp CustomCallWithLayout(
1276       XlaBuilder* builder, const std::string& call_target_name,
1277       absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
1278       absl::Span<const Shape> operand_shapes_with_layout,
1279       const std::string& opaque, bool has_side_effect,
1280       absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1281           output_operand_aliasing,
1282       const Literal* literal, CustomCallSchedule schedule,
1283       CustomCallApiVersion api_version);
1284   friend XlaOp CustomCallWithConvDnums(
1285       XlaBuilder* builder, const std::string& call_target_name,
1286       absl::Span<const XlaOp> operands, const Shape& shape,
1287       absl::Span<const Shape> operand_shapes_with_layout,
1288       const std::string& opaque, bool has_side_effect,
1289       absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
1290           output_operand_aliasing,
1291       const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
1292       CustomCallSchedule schedule, CustomCallApiVersion api_version);
1293   friend XlaOp OptimizationBarrier(XlaOp operand);
1294   friend XlaOp Complex(XlaOp real, XlaOp imag,
1295                        absl::Span<const int64_t> broadcast_dimensions);
1296   friend XlaOp Conj(XlaOp operand);
1297   friend XlaOp Add(XlaOp lhs, XlaOp rhs,
1298                    absl::Span<const int64_t> broadcast_dimensions);
1299   friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
1300                    absl::Span<const int64_t> broadcast_dimensions);
1301   friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
1302                    absl::Span<const int64_t> broadcast_dimensions);
1303   friend XlaOp Div(XlaOp lhs, XlaOp rhs,
1304                    absl::Span<const int64_t> broadcast_dimensions);
1305   friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
1306                    absl::Span<const int64_t> broadcast_dimensions);
1307   friend XlaOp Max(XlaOp lhs, XlaOp rhs,
1308                    absl::Span<const int64_t> broadcast_dimensions);
1309   friend XlaOp Min(XlaOp lhs, XlaOp rhs,
1310                    absl::Span<const int64_t> broadcast_dimensions);
1311   friend XlaOp And(XlaOp lhs, XlaOp rhs,
1312                    absl::Span<const int64_t> broadcast_dimensions);
1313   friend XlaOp Or(XlaOp lhs, XlaOp rhs,
1314                   absl::Span<const int64_t> broadcast_dimensions);
1315   friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
1316                    absl::Span<const int64_t> broadcast_dimensions);
1317   friend XlaOp Not(XlaOp operand);
1318   friend XlaOp PopulationCount(XlaOp operand);
1319   friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
1320                          absl::Span<const int64_t> broadcast_dimensions);
1321   friend XlaOp ShiftRightArithmetic(
1322       XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> broadcast_dimensions);
1323   friend XlaOp ShiftRightLogical(
1324       XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> broadcast_dimensions);
1325   friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
1326                       const XlaComputation& computation,
1327                       absl::Span<const int64_t> dimensions_to_reduce);
1328   friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1329                       absl::Span<const XlaOp> init_values,
1330                       const XlaComputation& computation,
1331                       absl::Span<const int64_t> dimensions_to_reduce);
1332   friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
1333                          const XlaComputation& computation);
1334   friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
1335                             const XlaComputation& computation,
1336                             absl::Span<const int64_t> window_dimensions,
1337                             absl::Span<const int64_t> window_strides,
1338                             Padding padding);
1339   friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
1340                             absl::Span<const XlaOp> init_values,
1341                             const XlaComputation& computation,
1342                             absl::Span<const int64_t> window_dimensions,
1343                             absl::Span<const int64_t> window_strides,
1344                             Padding padding);
1345   friend XlaOp ReduceWindowWithGeneralPadding(
1346       XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1347       absl::Span<const int64_t> window_dimensions,
1348       absl::Span<const int64_t> window_strides,
1349       absl::Span<const int64_t> base_dilations,
1350       absl::Span<const int64_t> window_dilations,
1351       absl::Span<const std::pair<int64_t, int64_t>> padding);
1352   friend XlaOp ReduceWindowWithGeneralPadding(
1353       absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
1354       const XlaComputation& computation,
1355       absl::Span<const int64_t> window_dimensions,
1356       absl::Span<const int64_t> window_strides,
1357       absl::Span<const int64_t> base_dilations,
1358       absl::Span<const int64_t> window_dilations,
1359       absl::Span<const std::pair<int64_t, int64_t>> padding);
1360 
1361   friend XlaOp CrossReplicaSum(XlaOp operand,
1362                                absl::Span<const ReplicaGroup> replica_groups);
1363   friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
1364                          int64_t shard_count,
1365                          absl::Span<const ReplicaGroup> replica_groups,
1366                          const std::optional<ChannelHandle>& channel_id,
1367                          const std::optional<Layout>& layout,
1368                          const std::optional<bool> use_global_device_ids);
1369   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
1370                          absl::Span<const ReplicaGroup> replica_groups,
1371                          const std::optional<ChannelHandle>& channel_id,
1372                          const std::optional<Shape>& shape_with_layout,
1373                          const std::optional<bool> use_global_device_ids);
1374   friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
1375                              int64_t scatter_dimension, int64_t shard_count,
1376                              absl::Span<const ReplicaGroup> replica_groups,
1377                              const std::optional<ChannelHandle>& channel_id,
1378                              const std::optional<Layout>& layout,
1379                              const std::optional<bool> use_global_device_ids);
1380 
1381   friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
1382                         int64_t concat_dimension, int64_t split_count,
1383                         absl::Span<const ReplicaGroup> replica_groups,
1384                         const std::optional<Layout>& layout);
1385   friend XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
1386                              absl::Span<const ReplicaGroup> replica_groups,
1387                              const std::optional<Layout>& layout);
1388   friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
1389                              int64_t concat_dimension, int64_t split_count,
1390                              absl::Span<const ReplicaGroup> replica_groups,
1391                              const std::optional<Layout>& layout);
1392   friend XlaOp CollectivePermute(
1393       XlaOp operand,
1394       const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
1395   friend XlaOp ReplicaId(XlaBuilder* builder);
1396   friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
1397                                 absl::Span<const int64_t> window_dimensions,
1398                                 absl::Span<const int64_t> window_strides,
1399                                 Padding padding, XlaOp source, XlaOp init_value,
1400                                 const XlaComputation& scatter);
1401   friend XlaOp SelectAndScatterWithGeneralPadding(
1402       XlaOp operand, const XlaComputation& select,
1403       absl::Span<const int64_t> window_dimensions,
1404       absl::Span<const int64_t> window_strides,
1405       absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
1406       XlaOp init_value, const XlaComputation& scatter);
1407   friend XlaOp Abs(XlaOp operand);
1408   friend XlaOp Atan2(XlaOp y, XlaOp x,
1409                      absl::Span<const int64_t> broadcast_dimensions);
1410   friend XlaOp Exp(XlaOp operand);
1411   friend XlaOp Expm1(XlaOp operand);
1412   friend XlaOp Floor(XlaOp operand);
1413   friend XlaOp Ceil(XlaOp operand);
1414   friend XlaOp Round(XlaOp operand);
1415   friend XlaOp RoundNearestEven(XlaOp operand);
1416   friend XlaOp Log(XlaOp operand);
1417   friend XlaOp Log1p(XlaOp operand);
1418   friend XlaOp Logistic(XlaOp operand);
1419   friend XlaOp Sign(XlaOp operand);
1420   friend XlaOp Clz(XlaOp operand);
1421   friend XlaOp Cos(XlaOp operand);
1422   friend XlaOp Sin(XlaOp operand);
1423   friend XlaOp Tanh(XlaOp operand);
1424   friend XlaOp Real(XlaOp operand);
1425   friend XlaOp Imag(XlaOp operand);
1426   friend XlaOp Sqrt(XlaOp operand);
1427   friend XlaOp Rsqrt(XlaOp operand);
1428   friend XlaOp Cbrt(XlaOp operand);
1429   friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
1430                    absl::Span<const int64_t> broadcast_dimensions);
1431   friend XlaOp IsFinite(XlaOp operand);
1432   friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
1433                     int64_t iota_dimension);
1434   friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);
1435   friend XlaOp ConvertElementType(XlaOp operand,
1436                                   PrimitiveType new_element_type);
1437   friend XlaOp BitcastConvertType(XlaOp operand,
1438                                   PrimitiveType new_element_type);
1439   friend XlaOp Neg(XlaOp operand);
1440   friend XlaOp Transpose(XlaOp operand, absl::Span<const int64_t> permutation);
1441   friend XlaOp Rev(XlaOp operand, absl::Span<const int64_t> dimensions);
1442   friend XlaOp Sort(absl::Span<const XlaOp> operands,
1443                     const XlaComputation& comparator, int64_t dimension,
1444                     bool is_stable);
1445   friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
1446   friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1447                    const XlaComputation& computation,
1448                    absl::Span<const int64_t> dimensions,
1449                    absl::Span<const XlaOp> static_operands);
1450   friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
1451   friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
1452   friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
1453                                const Shape& shape);
1454   friend XlaOp While(const XlaComputation& condition,
1455                      const XlaComputation& body, XlaOp init);
1456   friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
1457                            const XlaComputation& true_computation,
1458                            XlaOp false_operand,
1459                            const XlaComputation& false_computation);
1460   friend XlaOp Conditional(
1461       XlaOp branch_index,
1462       absl::Span<const XlaComputation* const> branch_computations,
1463       absl::Span<const XlaOp> branch_operands);
1464   friend XlaOp ConditionalImpl(
1465       XlaOp branch_index,
1466       absl::Span<const XlaComputation* const> branch_computations,
1467       absl::Span<const XlaOp> branch_operands);
1468   friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
1469                                const int mantissa_bits);
1470   friend XlaOp Gather(XlaOp input, XlaOp start_indices,
1471                       const GatherDimensionNumbers& dimension_numbers,
1472                       absl::Span<const int64_t> slice_sizes,
1473                       bool indices_are_sorted);
1474   friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1475                        const XlaComputation& update_computation,
1476                        const ScatterDimensionNumbers& dimension_numbers,
1477                        bool indices_are_sorted, bool unique_indices);
1478   friend XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
1479                        absl::Span<const XlaOp> updates,
1480                        const XlaComputation& update_computation,
1481                        const ScatterDimensionNumbers& dimension_numbers,
1482                        bool indices_are_sorted, bool unique_indices);
1483   friend void Send(XlaOp operand, const ChannelHandle& handle);
1484   friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1485                     const ChannelHandle& handle);
1486   friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
1487                                  float epsilon, int64_t feature_index);
1488   friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
1489                                   XlaOp mean, XlaOp variance, float epsilon,
1490                                   int64_t feature_index);
1491   friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
1492                              XlaOp batch_var, XlaOp grad_output, float epsilon,
1493                              int64_t feature_index);
1494   friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
1495                              const ChannelHandle& handle);
1496   friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
1497                              const ChannelHandle& handle);
1498   friend XlaOp SendToHost(XlaOp operand, XlaOp token,
1499                           const Shape& shape_with_layout,
1500                           const ChannelHandle& handle);
1501   friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
1502                             const ChannelHandle& handle);
1503   friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1504                                const std::string& config);
1505   friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1506                                 const Shape& shape_with_layout,
1507                                 const std::string& outfeed_config);
1508   friend XlaOp CreateToken(XlaBuilder* builder);
1509   friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1510 
1511   friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);
1512   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);
1513   friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);
1514 
1515  protected:
1516   // Returns OK status if the given op was built using this builder. Otherwise,
1517   // returns an error.
1518   Status CheckOpBuilder(XlaOp op) const;
1519 
1520  private:
1521   XlaOp ConditionalImpl(
1522       XlaOp branch_index,
1523       absl::Span<const XlaComputation* const> branch_computations,
1524       absl::Span<const XlaOp> branch_operands);
1525 
1526   XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
1527                       int64_t concat_dimension, int64_t split_count,
1528                       absl::Span<const ReplicaGroup> replica_groups);
1529 
1530   // Creates an op with the given opcode and the output shape.
1531   virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
1532                                          absl::Span<const XlaOp> operands);
1533 
1534   // Here, InstructionType is either const HloInstructionProto* or non-const
1535   // HloInstructionProto*.
1536   template <typename InstructionType>
LookUpInstructionByHandleInternal(int64_t handle)1537   StatusOr<InstructionType> LookUpInstructionByHandleInternal(
1538       int64_t handle) const {
1539     auto it = handle_to_index_.find(handle);
1540     if (it == handle_to_index_.end()) {
      // Try to find the instruction among the imported instructions.
      auto imported_it = handle_to_imported_index_.find(handle);
      if (imported_it != handle_to_imported_index_.end()) {
        ImportedInstruction imported = imported_it->second;
        return const_cast<InstructionType>(
            &embedded_.at(imported.computation_id)
                 .instructions(imported.instruction_index));
      }
      return InvalidArgument("No XlaOp with handle %d", handle);
    }
    return const_cast<InstructionType>(&instructions_.at(it->second));
  }

  // Here, InstructionType is either const HloInstructionProto* or non-const
  // HloInstructionProto*.
  //
  // TODO(hinsu): Return const pointer within StatusOr and use
  // absl::implicit_cast at callsites. This requires implicit_cast support in
  // stream_executor::port::StatusOr similar to absl::StatusOr.
  template <typename InstructionType>
  StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
    TF_RETURN_IF_ERROR(CheckOpBuilder(op));
    return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
  }

  friend struct internal::XlaBuilderFriend;

  friend class ValueInference;
};

// RAII-style object: sets the current sharding assignment in builder on
// construction, and restores the previous assignment on destruction.
class XlaScopedShardingAssignment {
 public:
  XlaScopedShardingAssignment(xla::XlaBuilder* builder,
                              std::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
  XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
      delete;

  ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  void SetSharding(const std::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::XlaBuilder* const builder_;
  std::optional<OpSharding> prev_sharding_;
};

// RAII-style object: saves the current builder's frontend attributes on
// construction and merges them with the new ones; restores the original
// attributes on destruction.
class XlaScopedFrontendAttributesAssignment {
 public:
  XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
                                        FrontendAttributes attributes)
      : builder_(builder) {
    saved_ = builder_->SwapFrontendAttributes(attributes);
  }

  ~XlaScopedFrontendAttributesAssignment() {
    builder_->SetFrontendAttributes(saved_);
  }

 private:
  xla::XlaBuilder* const builder_;
  FrontendAttributes saved_;

  XlaScopedFrontendAttributesAssignment(
      const XlaScopedFrontendAttributesAssignment&) = delete;
  XlaScopedFrontendAttributesAssignment& operator=(
      const XlaScopedFrontendAttributesAssignment&) = delete;
};

// RAII-style object: sets the current op metadata in builder on construction,
// and restores the previous metadata on destruction.
class XlaScopedOpMetadataAssignment {
 public:
  XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata)
      : builder_(builder) {
    saved_ = builder_->SwapOpMetadata(metadata);
  }

  ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); }

 private:
  xla::XlaBuilder* const builder_;
  OpMetadata saved_;

  XlaScopedOpMetadataAssignment(const XlaScopedOpMetadataAssignment&) = delete;
  XlaScopedOpMetadataAssignment& operator=(
      const XlaScopedOpMetadataAssignment&) = delete;
};

// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.
//

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const std::string& name);

// Same as above, but with leaf buffer replication annotation.
XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const std::string& name,
                const std::vector<bool>& replicated_at_leaf_buffers);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
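
// Example (an illustrative sketch added by the editor, not part of the
// original comments): a parameter combined with a constant. The builder name,
// shapes, and values are arbitrary.
//
//   XlaBuilder b("add_constant");
//   XlaOp x = Parameter(&b, /*parameter_number=*/0,
//                       ShapeUtil::MakeShape(F32, {2}), "x");
//   XlaOp c = ConstantLiteral(&b, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
//   Add(x, c);
//   StatusOr<XlaComputation> computation = b.Build();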

// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//  Native Type   PrimitiveType
// -----------------------------
//   bool           PRED
//   int32_t        S32
//   int64_t        S64
//   uint32_t       U32
//   uint64_t       U64
//   float          F32
//   double         F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value);
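
// Example (a sketch; given an XlaBuilder `b`, values are arbitrary): the XLA
// element type of the constant is inferred from NativeT per the table above.
//
//   XlaOp scalar = ConstantR0<float>(&b, 42.0f);               // f32[]
//   XlaOp vec = ConstantR1<int32_t>(&b, {1, 2, 3});            // s32[3]
//   XlaOp zeros = ConstantR1<float>(&b, /*length=*/4, 0.0f);   // f32[4]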

// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(XlaOp operand, absl::Span<const int64_t> broadcast_sizes);
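
// Example (a sketch, given an XlaBuilder `b`): broadcasting an s32[2] with
// broadcast_sizes {3} yields s32[3, 2], repeating the operand along the new
// major dimension.
//
//   XlaOp v = ConstantR1<int32_t>(&b, {1, 2});  // s32[2]
//   XlaOp m = Broadcast(v, {3});                // {{1, 2}, {1, 2}, {1, 2}}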

// This op broadcasts the `operand` to an output with the given `out_dim_size`.
// `broadcast_dimensions` are the dimensions to broadcast into, i.e., the
// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
// dimension of the output. This also requires that the i'th input dimension is
// either 1 or is the same as the output dimension it's broadcasting into.
//
// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
// output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimensions will generate output
//   {{1, 2},
//    {1, 2}}
// - On the other hand, specifying {0} as broadcast_dimensions
//   will generate output
//   {{1, 1},
//    {2, 2}}
XlaOp BroadcastInDim(XlaOp operand,
                     const absl::Span<const int64_t> out_dim_size,
                     const absl::Span<const int64_t> broadcast_dimensions);
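
// Example (a sketch, given an XlaBuilder `b`) reproducing the second case
// above:
//
//   XlaOp v = ConstantR1<int32_t>(&b, {1, 2});  // s32[2]
//   XlaOp m = BroadcastInDim(v, /*out_dim_size=*/{2, 2},
//                            /*broadcast_dimensions=*/{0});
//   // m is {{1, 1}, {2, 2}}.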

// Copies the input operand to the output. This operation is for internal
// purposes only and is used by the compiler for optimization or to ensure
// correctness. The XLA client should never have to generate this instruction.
//
// Copy has two potential use cases:
//
// * Create a copy of the operand with a new layout.
//
// * Create a copy of the operand in a separately allocated buffer. This is
//   necessary for some backends if the operand is a parameter or constant and
//   the operand is returned within a tuple. In this case, the lifetime of the
//   operand buffer must be the same as the lifetime of the output result.
//   However, the lifetimes of parameters and constants are managed separately
//   from the lifetime of the output result. Creating a separate copy of the
//   parameter or constant buffer resolves this issue.
XlaOp Copy(XlaOp operand);

// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(XlaOp operand, XlaOp padding_value,
          const PaddingConfig& padding_config);

// Enqueues a pad operation in a given dimension, taking all other
// dimensions as they are.
XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
               int64_t pad_lo, int64_t pad_hi);
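
// Example (a sketch, assuming `operand` is a previously built rank-1 f32[3]
// op and `b` is its XlaBuilder): one element of low edge padding and two
// elements of interior padding between neighbors, so {a, b, c} becomes
// {0, a, 0, 0, b, 0, 0, c}.
//
//   PaddingConfig config;
//   auto* dim = config.add_dimensions();
//   dim->set_edge_padding_low(1);
//   dim->set_edge_padding_high(0);
//   dim->set_interior_padding(2);
//   XlaOp zero = ConstantR0<float>(&b, 0.0f);
//   XlaOp padded = Pad(operand, zero, config);  // f32[8]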

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
              absl::Span<const int64_t> new_sizes);

// Enqueues a dynamic reshape operation. The dynamic reshape takes additional
// XlaOps as sizes for the result dimensions. Result dim i is a dynamic
// dimension if dims_are_dynamic[i] is true.
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64_t> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic);

// Enqueues an operation onto the computation that collapses the operand,
// from first to last dimension (C order), then reshapes it to the given
// dimension sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes);
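
// Example (a sketch): reshaping f32[2, 3] into f32[3, 2]. Elements are
// consumed in major-to-minor order, so {{1, 2, 3}, {4, 5, 6}} becomes
// {{1, 2}, {3, 4}, {5, 6}}.
//
//   XlaOp r = Reshape(operand, /*new_sizes=*/{3, 2});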

// Enqueues a Reshape op that uses an explicit target shape.
XlaOp Reshape(const Shape& shape, XlaOp operand);

// `inferred_dimension` represents the output dimension that's inferred by
// the upper-level framework by dividing the input element count by the known
// output element count. While an inferred_dimension can be static, if there
// is a dynamic dimension in the output, it must be the inferred dimension.
XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64_t> new_sizes,
                                   int64_t inferred_dimension);

// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
//    {256} collapsing {0} => {256}
//    {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
//    {256, 2} collapsing {0,1} => {512}
//    {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(XlaOp operand, absl::Span<const int64_t> dimensions);
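
// Example (a sketch) matching the description above:
//
//   // operand has shape f32[256, 2, 2, 32].
//   XlaOp c = Collapse(operand, /*dimensions=*/{0, 1, 2});  // f32[1024, 32]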

// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
//        x
//   [ 0 1 2 3 ]
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
//   [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice.
XlaOp Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
            absl::Span<const int64_t> limit_indices,
            absl::Span<const int64_t> strides);
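
// Example (a sketch) reproducing the diagram above on an s32[3, 4] operand:
//
//   XlaOp s = Slice(x, /*start_indices=*/{1, 1}, /*limit_indices=*/{2, 3},
//                   /*strides=*/{1, 1});  // s32[1, 2] containing {{5, 6}}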

// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
//  array[:, 2:4:1, :]
XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index,
                 int64_t stride, int64_t dimno);

// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',
// which specify the end point of exclusive slice intervals in each
// dimension [start, start + size).
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices' must
// have the same shape.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64_t> slice_sizes);
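
// Example (a sketch, assuming `x` is a previously built f32[4, 4] op): taking
// a 2x2 window at offsets computed at runtime. Each start index is a scalar
// XlaOp.
//
//   XlaOp i = ConstantR0<int32_t>(&b, 1);
//   XlaOp j = ConstantR0<int32_t>(&b, 2);
//   XlaOp window = DynamicSlice(x, {i, j}, /*slice_sizes=*/{2, 2});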

// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
//               update = {10, 11} // calculated at runtime.
//   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
//   [4 5 6]  => DynamicUpdateslice(data, update, start)   => [4 10 11]
//   [7 8 9]                                                  [7 8  9 ]
//
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices' must
// have the same shape.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                         absl::Span<const XlaOp> start_indices);
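
// Example (a sketch, assuming `data` is a previously built s32[3, 3] op)
// matching the diagram above:
//
//   XlaOp row = ConstantR0<int32_t>(&b, 1);
//   XlaOp col = ConstantR0<int32_t>(&b, 1);
//   XlaOp update = ConstantR2<int32_t>(&b, {{10, 11}});  // s32[1, 2]
//   XlaOp result = DynamicUpdateSlice(data, update, {row, col});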

// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64_t dimension);

// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);

// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
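
// Example (a sketch): packing two values into a tuple and reading one back.
//
//   XlaOp t = Tuple(&b, {x, y});
//   XlaOp y2 = GetTupleElement(t, 1);  // Same value as y.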

// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a comparison instruction onto the computation; overloads are
// provided with and without broadcast_dimensions and an explicit comparison
// type, for consistency with the helpers above.
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type);
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions,
              ComparisonDirection direction);
XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
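
// Example (a sketch): the direction-specific helpers above are shorthand for
// Compare; these two ops are equivalent.
//
//   XlaOp p1 = Lt(x, y);
//   XlaOp p2 = Compare(x, y, ComparisonDirection::kLt);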

// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,
          const PrecisionConfig* precision_config = nullptr,
          std::optional<PrimitiveType> preferred_element_type = std::nullopt);

// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config = nullptr,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt);
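
// Example (a sketch): expressing a plain f32[m, k] x f32[k, n] matrix multiply
// via DotGeneral by contracting lhs dimension 1 against rhs dimension 0.
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);
//   dnums.add_rhs_contracting_dimensions(0);
//   XlaOp prod = DotGeneral(lhs, rhs, dnums);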

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
           Padding padding, int64_t feature_group_count = 1,
           int64_t batch_group_count = 1,
           const PrecisionConfig* precision_config = nullptr,
           std::optional<PrimitiveType> preferred_element_type = std::nullopt);
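
// Example (a sketch, assuming `input` and `kernel` are previously built ops
// laid out per the default convolution dimension numbers): a stride-1,
// same-padded 2D convolution.
//
//   XlaOp out = Conv(input, kernel, /*window_strides=*/{1, 1}, Padding::kSame);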

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    int64_t feature_group_count = 1, int64_t batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count = 1, int64_t batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count = 1, int64_t batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count = 1, int64_t batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt,
    std::optional<std::vector<bool>> window_reversal = std::nullopt);

XlaOp DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt);

XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt);

XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    std::optional<PrimitiveType> preferred_element_type = std::nullopt);

// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(XlaOp operand, FftType fft_type,
          absl::Span<const int64_t> fft_length);
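
// Example (a sketch): a forward complex FFT over the last dimension of a
// c64[..., 16] operand; fft_length must match the transformed dimensions.
//
//   XlaOp freq = Fft(x, FftType::FFT, /*fft_length=*/{16});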

// Solves systems of linear equations with lower or upper triangular
// coefficient matrices by forward- or back-substitution. Broadcasting along
// leading dimensions, this routine solves for x in one of the matrix systems
//   `op(a) * x = b`,  or `x * op(a) = b`,
// for the variable `x`, given `a` and `b`, where `op(a)` is either
//   `op(a) = a`,  or `op(a) = transpose(a)`,  or `op(a) = conj(transpose(a))`.
//
// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
//   square matrices. If `lower` is true (false), then the strictly upper
//   (lower) triangular part of each innermost matrix in `a` is assumed to be
//   zero and is not accessed.
// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
//   tensor of shape `[..., K, M]`.
// * `left_side` is a boolean, indicating whether to solve a system of the form
//   op(a) * x = b (true) or x * op(a) = b (false).
// * `lower` is a boolean, indicating whether the argument `a` is
//   lower-triangular (true) or upper-triangular (false).
// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
//   1 and not accessed.
// * `transpose_a` indicates which function `op` we use to transform the tensor
//   `a`: the identity function, transpose(a), or conjugate(transpose(a)).
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a);
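
// Example (a sketch): solving a * x = b for x with a lower-triangular `a`,
// no transposition, and a non-unit diagonal.
//
//   XlaOp x = TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);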

// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
// positive definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
// two minor dimensions equal.
// If `lower` is true, the data from the lower triangle is used; if false, the
// upper triangle is used. The input data in the other triangle of the input
// does not affect the output. Returns the output in the same lower/upper
// triangle. The data returned in the other output triangle is arbitrary and
// implementation-defined.
//
// If `a` is not Hermitian positive definite, returns an array full of NaNs.
XlaOp Cholesky(XlaOp a, bool lower);

// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const std::string& config = "");

// Variant of Infeed which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                      const std::string& config = "");

// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(XlaOp operand, const Shape& shape_with_layout,
             const std::string& outfeed_config);

// Variant of Outfeed which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                       const Shape& shape_with_layout,
                       const std::string& outfeed_config);

// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands);

// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external code,
// and the external code is expected to produce a result of the given
// |shape|. The exact mechanism is backend-specific. For example, in the CPU
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|. |call_target_name| and |opaque| can be arbitrary
// strings, but |call_target_name| should be short as it may be used in labels.
// |opaque| can encode arbitrarily large amounts of information.
// |has_side_effect| specifies whether the instruction can have side effects.
// |output_operand_aliasing| specifies a list of output/operand buffer pairs
// that alias each other, where the output buffer is represented as a
// ShapeIndex, and the operand buffer is represented as the operand index and
// the ShapeIndex.
XlaOp CustomCall(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    const std::string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr,
    CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
    CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
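
// Example (a sketch; "my_host_fn" is a hypothetical target name that would
// have to be registered with the backend separately): calling external code
// that takes one f32[4] operand and produces an f32[4] result.
//
//   XlaOp out = CustomCall(&b, "my_host_fn", {x},
//                          ShapeUtil::MakeShape(F32, {4}));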

// Overload which constructs a custom call that applies an XLA computation.
XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const std::string& opaque = "",
    bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr,
    CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
    CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
// external code, and the external code is expected to produce a result with the
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
// and |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
    absl::Span<const Shape> operand_shapes_with_layout,
    const std::string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr,
    CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
    CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

// Overload which annotates a custom call with the given Window and
// ConvolutionDimensionNumbers.  Useful for custom-calls which represent
// convolutions.
//
// This sets the layout of its operands if operand_shapes_with_layout is
// nonempty, and it sets the layout of its result if `shape` has a layout.
XlaOp CustomCallWithConvDnums(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout,
    const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
    CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
    CustomCallApiVersion api_version = API_VERSION_ORIGINAL);

// Enqueues an optimization barrier onto the computation.
XlaOp OptimizationBarrier(XlaOp operand);

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).

2194 // Enqueues a complex compose instruction onto the computation.
2195 XlaOp Complex(XlaOp real, XlaOp imag,
2196               absl::Span<const int64_t> broadcast_dimensions = {});
2197 
2198 // Enqueues a complex conjugate instruction onto the computation.
2199 XlaOp Conj(XlaOp operand);
2200 
2201 // Enqueues an add instruction onto the computation.
2202 XlaOp Add(XlaOp lhs, XlaOp rhs,
2203           absl::Span<const int64_t> broadcast_dimensions = {});
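
// Example usage of Add with broadcast_dimensions (a minimal sketch; `b` is an
// XlaBuilder): broadcasting an f32[2] vector across dimension 1 of an
// f32[2,2] matrix. broadcast_dimensions={1} maps dimension 0 of `v` to
// dimension 1 of `m`.
//
//   XlaOp m = ConstantR2<float>(&b, {{1.0f, 2.0f}, {3.0f, 4.0f}});
//   XlaOp v = ConstantR1<float>(&b, {10.0f, 20.0f});
//   XlaOp sum = Add(m, v, /*broadcast_dimensions=*/{1});
//   // sum is {{11, 22}, {13, 24}}.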

// Enqueues a subtract instruction onto the computation.
XlaOp Sub(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a multiply instruction onto the computation.
XlaOp Mul(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a divide instruction onto the computation.
XlaOp Div(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a remainder instruction onto the computation.
XlaOp Rem(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a max instruction onto the computation.
XlaOp Max(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues a min instruction onto the computation.
XlaOp Min(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Element-wise logical operators.
XlaOp And(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Overload to call And with 3 or more operands. We need the following somewhat
// convoluted overload set to disambiguate with the overload that takes the
// `broadcast_dimensions` optional param.
inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
  return And(op1, And(op2, op3));
}
template <typename... XlaOpTs>
XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return And(op1, And(op2, And(op3, operands...)));
}
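
// For example, And(a, b, c, d) expands to And(a, And(b, And(c, d))).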

XlaOp Or(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions = {});

// Overload to call Or with 3 or more operands. As with `And`, we need the
// following complicated overload set to handle the default arg in the `Or`
// overload above.
inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
  return Or(op1, Or(op2, op3));
}
template <typename... XlaOpTs>
XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return Or(op1, Or(op2, Or(op3, operands...)));
}

XlaOp Xor(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

XlaOp Not(XlaOp operand);

XlaOp PopulationCount(XlaOp operand);

XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64_t> broadcast_dimensions = {});
XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
                        absl::Span<const int64_t> broadcast_dimensions = {});

// Reduces an array over the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
             absl::Span<const int64_t> dimensions_to_reduce);

// Reduces several arrays simultaneously over the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64_t> dimensions_to_reduce);
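
// Example usage of Reduce (a minimal sketch; `b` is an XlaBuilder): summing an
// f32[2,3] array along dimension 1 with a scalar-add reduction computation
// built from a sub-builder.
//
//   XlaBuilder sub("add");
//   auto scalar = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(&sub, 0, scalar, "x"), Parameter(&sub, 1, scalar, "y"));
//   XlaComputation add = sub.Build().ValueOrDie();
//
//   XlaOp m = ConstantR2<float>(&b, {{1, 2, 3}, {4, 5, 6}});
//   XlaOp row_sums = Reduce(m, ConstantR0<float>(&b, 0.0f), add,
//                           /*dimensions_to_reduce=*/{1});
//   // row_sums is f32[2] {6, 15}.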

// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                const XlaComputation& computation);

// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64_t> window_dimensions,
                   absl::Span<const int64_t> window_strides, Padding padding);

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64_t> window_dimensions,
                   absl::Span<const int64_t> window_strides, Padding padding);
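
// Example usage of ReduceWindow (a minimal sketch; `max` is a scalar-max
// XlaComputation built like the scalar-add computation above): a 2x2 max pool
// with stride 2 over an f32[4,4] input.
//
//   XlaOp init =
//       ConstantR0<float>(&b, -std::numeric_limits<float>::infinity());
//   XlaOp pooled = ReduceWindow(input, init, max,
//                               /*window_dimensions=*/{2, 2},
//                               /*window_strides=*/{2, 2}, Padding::kValid);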

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    XlaOp operand, XlaOp init_value, const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding);
XlaOp ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});

XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
                int64_t shard_count,
                absl::Span<const ReplicaGroup> replica_groups = {},
                const std::optional<ChannelHandle>& channel_id = std::nullopt,
                const std::optional<Layout>& layout = std::nullopt,
                const std::optional<bool> use_global_device_ids = std::nullopt);

// Enqueues an operation that does an AllReduce of the operand across cores.
// Here AllReduce means doing a reduction on the input operand across cores and
// then broadcasting the reduction result to those cores. The reduction
// function is defined by `computation`, which should be a commutative
// computation on scalars, e.g., add, min, or max. The way that AllReduce is
// applied is configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
// empty, all replicas belong to one group. AllReduce will be applied within
// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
// that replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in
// subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied across modules.
//
// - `shape_with_layout`: forces the layout of the AllReduce to the given
// layout. This is used to guarantee the same layout for a group of AllReduce
// ops compiled separately.
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups = {},
                const std::optional<ChannelHandle>& channel_id = std::nullopt,
                const std::optional<Shape>& shape_with_layout = std::nullopt,
                const std::optional<bool> use_global_device_ids = std::nullopt);
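
// Example usage of AllReduce (a minimal sketch; `add` is a scalar-add
// XlaComputation): summing within subgroups {0,2} and {1,3} of 4 replicas.
//
//   ReplicaGroup g0, g1;
//   g0.add_replica_ids(0);
//   g0.add_replica_ids(2);
//   g1.add_replica_ids(1);
//   g1.add_replica_ids(3);
//   XlaOp summed = AllReduce(x, add, /*replica_groups=*/{g0, g1});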

XlaOp ReduceScatter(
    XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
    int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups = {},
    const std::optional<ChannelHandle>& channel_id = std::nullopt,
    const std::optional<Layout>& layout = std::nullopt,
    const std::optional<bool> use_global_device_ids = std::nullopt);

// Enqueues an operation that does an AllToAll of the operand across cores.
// An optional `layout` can be specified to force the layout of the instruction.
// This is used to guarantee the same layout for a group of AllToAll ops
// compiled separately.
XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension,
               int64_t split_count,
               absl::Span<const ReplicaGroup> replica_groups = {},
               const std::optional<Layout>& layout = std::nullopt);

XlaOp AllToAllTuple(absl::Span<const XlaOp> operand,
                    absl::Span<const ReplicaGroup> replica_groups = {},
                    const std::optional<Layout>& layout = std::nullopt);

XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
                    int64_t concat_dimension, int64_t split_count,
                    absl::Span<const ReplicaGroup> replica_groups = {},
                    const std::optional<Layout>& layout = std::nullopt);

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that 1) no two pairs should have the same target
// replica id, and no two pairs should have the same source replica id; 2) if a
// replica id is not a target in any pair, then the output on that replica is a
// tensor consisting of 0(s) with the same shape as the input.
XlaOp CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
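
// Example usage of CollectivePermute (a minimal sketch): rotating data one
// replica to the right among 4 replicas. Every replica id appears as a
// target, so no replica receives zeros.
//
//   XlaOp rotated = CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});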

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64_t> window_dimensions,
                       absl::Span<const int64_t> window_strides,
                       Padding padding, XlaOp source, XlaOp init_value,
                       const XlaComputation& scatter);

// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(XlaOp operand);

// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(XlaOp y, XlaOp x,
            absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(XlaOp operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(XlaOp operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(XlaOp operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(XlaOp operand);

// Enqueues a round instruction onto the computation,
// with half-way cases rounding away from zero.
XlaOp Round(XlaOp operand);

// Enqueues a round instruction onto the computation, rounding to nearest even.
XlaOp RoundNearestEven(XlaOp operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(XlaOp operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(XlaOp operand);

// Enqueues a logistic instruction onto the computation.
XlaOp Logistic(XlaOp operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(XlaOp operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(XlaOp operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(XlaOp operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(XlaOp operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(XlaOp operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(XlaOp operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(XlaOp operand);

// Enqueues a sqrt computation onto the computation.
XlaOp Sqrt(XlaOp operand);

// Enqueues a cbrt computation onto the computation.
XlaOp Cbrt(XlaOp operand);

// Enqueues a rsqrt computation onto the computation.
XlaOp Rsqrt(XlaOp operand);

// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e., not
// +/-Inf or NaN. Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(XlaOp operand);

// Enqueues an iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
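
// Example contrasting the two conversions (a minimal sketch; `b` is an
// XlaBuilder):
//
//   XlaOp i = ConstantR0<int32_t>(&b, 1);
//   XlaOp f = ConvertElementType(i, F32);    // value-preserving: 1.0f
//   XlaOp raw = BitcastConvertType(f, S32);  // bit-preserving: 0x3f800000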

// Enqueues a negate instruction onto the computation.
XlaOp Neg(XlaOp operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(XlaOp operand, absl::Span<const int64_t> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(XlaOp operand, absl::Span<const int64_t> dimensions);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided, then the
//   last dimension is chosen by default. For the dimension which is sorted, the
//   same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types of
//   the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
//   two index positions.
// Default comparator computations can be found in lib/comparators.h.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension = -1, bool is_stable = false);
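
// Example usage of Sort (a minimal sketch; CreateScalarLtComputation is
// declared in lib/comparators.h): sorting f32 keys in ascending order and
// applying the same permutation to s32 values.
//
//   XlaComputation less = CreateScalarLtComputation({F32, S32}, &b);
//   XlaOp sorted = Sort({keys, values}, less, /*dimension=*/0);
//   XlaOp sorted_keys = GetTupleElement(sorted, 0);
//   XlaOp sorted_values = GetTupleElement(sorted, 1);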

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation,
          absl::Span<const int64_t> dimensions,
          absl::Span<const XlaOp> static_operands = {});
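
// Example usage of Map (a minimal sketch; `scalar_comp` is an XlaComputation
// on scalars): applying a scalar computation elementwise over two f32[2,3]
// arrays.
//
//   XlaOp out = Map(&b, {x, y}, scalar_comp, /*dimensions=*/{0, 1});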

// Enqueues an N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

// Enqueues a B(initial_state) random bit generation instruction onto the
// computation. Returns the new key and random bits with the specified shape.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                      const Shape& shape);

// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            XlaOp init);
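
// Example usage of While (a minimal sketch; `b` is an XlaBuilder): counting
// from 0 up to 10.
//
//   auto s32 = ShapeUtil::MakeShape(S32, {});
//   XlaBuilder condb("cond");
//   Lt(Parameter(&condb, 0, s32, "i"), ConstantR0<int32_t>(&condb, 10));
//   XlaBuilder bodyb("body");
//   Add(Parameter(&bodyb, 0, s32, "i"), ConstantR0<int32_t>(&bodyb, 1));
//   XlaOp result = While(condb.Build().ValueOrDie(),
//                        bodyb.Build().ValueOrDie(),
//                        /*init=*/ConstantR0<int32_t>(&b, 0));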

// Enqueues a conditional node onto the computation.
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                  const XlaComputation& true_computation, XlaOp false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index uses the (N-1)th
// branch_computation as the default.
XlaOp Conditional(XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);
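
// Example usage of the indexed form (a minimal sketch; branch0..branch2 are
// XlaComputations and arg0..arg2 their respective operands):
//
//   XlaOp r = Conditional(branch_index, {&branch0, &branch1, &branch2},
//                         {arg0, arg1, arg2});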

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(XlaOp input, XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64_t> slice_sizes,
             bool indices_are_sorted = false);

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);
XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
              absl::Span<const XlaOp> updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
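
// Example of ordering two side-effecting ops before later ones (a minimal
// sketch; `h1` and `h2` are hypothetical ChannelHandles):
//
//   XlaOp tok = CreateToken(&b);
//   XlaOp t1 = SendWithToken(x, tok, h1);
//   XlaOp t2 = SendWithToken(y, tok, h2);
//   XlaOp joined = AfterAll(&b, {t1, t2});  // later ops can consume `joined`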

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64_t feature_index);
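
// Example usage of BatchNormTraining (a minimal sketch): normalizing
// f32[batch, features] activations, where dimension 1 is the feature
// dimension.
//
//   XlaOp res = BatchNormTraining(act, scale, offset, /*epsilon=*/1e-5,
//                                 /*feature_index=*/1);
//   XlaOp normalized = GetTupleElement(res, 0);
//   XlaOp batch_mean = GetTupleElement(res, 1);
//   XlaOp batch_var = GetTupleElement(res, 2);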

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64_t feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_offset: Gradient with respect to input `offset`
//   - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64_t feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

// Returns the same op but with the dynamic dimension removed.
XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);
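
// Example of dynamic-dimension bookkeeping (a minimal sketch; `n` is an s32
// scalar XlaOp):
//
//   XlaOp dyn = SetDimensionSize(x, n, /*dimension=*/0);  // dim 0 is dynamic
//   XlaOp size = GetDimensionSize(dyn, /*dimension=*/0);  // yields `n`
//   XlaOp fixed = RemoveDynamicDimension(dyn, /*dimension=*/0);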

// Implementation details below this point.
//

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  BorrowingLiteral literal(
      reinterpret_cast<const char*>(values.begin()),
      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
                           {static_cast<int64_t>(values.size())}));
  return ConstantLiteral(builder, literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

// Switches from automatic SPMD partitioning to manual partitioning. Converts a
// full-shaped tensor (to be automatically partitioned by the SPMD partitioner)
// to a shard-shaped tensor to be consumed by manually partitioned ops.
StatusOr<xla::XlaOp> ConvertSpmdFullToShardShape(
    xla::XlaBuilder* builder, xla::XlaOp input, int single_dim,
    const xla::OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims);

// Switches from manual partitioning to automatic SPMD partitioning. Converts a
// shard-shaped tensor (manually partitioned in SPMD-style) to a full-shaped
// tensor to be partitioned automatically by the SPMD partitioner.
StatusOr<xla::XlaOp> ConvertSpmdShardToFullShape(
    xla::XlaBuilder* builder, xla::XlaOp input, const xla::Shape& output_shape,
    int single_dim, const xla::OpSharding& manual_sharding,
    absl::Span<const int64_t> unspecified_dims);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_