/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;
class XlaOp;
class HloInstruction;

namespace internal {

struct XlaBuilderFriend {
  static XlaOp BuildFusion(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           absl::string_view fusion_kind,
                           const XlaComputation& fused_computation);

  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
                            const Shape& shape);

  static HloInstructionProto* GetInstruction(XlaOp op);
};

}  // namespace internal

// This represents an instruction that has been enqueued using the XlaBuilder.
// It is passed to subsequent computations that depend on the instruction as an
// operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar().  Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point possibly
  // deep in the callstack when we finally dereference `this`.  The precondition
  // lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(XlaOp rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64_t handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;
  friend class ValueInference;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};
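
// Example (a minimal sketch, not part of the declarations below):
// default-constructed XlaOps carry no builder and an invalid handle, while
// ops returned by a builder are valid. Parameter is the free function
// declared later in this header; ShapeUtil and F32 come from the included
// shape headers.
//
//   XlaOp uninit;                     // IsUninitialized() and !valid()
//   XlaBuilder b("example");
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "p");
//   CHECK(p.valid() && p.IsIdenticalTo(p));
//   CHECK_EQ(p.builder(), &b);        // precondition holds: builder_ != null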

// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(XlaOp x);
XlaOp operator+(XlaOp x, XlaOp y);
XlaOp operator-(XlaOp x, XlaOp y);
XlaOp operator*(XlaOp x, XlaOp y);
XlaOp operator/(XlaOp x, XlaOp y);
XlaOp operator%(XlaOp x, XlaOp y);

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(XlaOp x);
XlaOp operator&(XlaOp x, XlaOp y);
XlaOp operator|(XlaOp x, XlaOp y);
XlaOp operator^(XlaOp x, XlaOp y);
XlaOp operator<<(XlaOp x, XlaOp y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
// a right logical shift.
XlaOp operator>>(XlaOp x, XlaOp y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
// semantics might be surprising since their result types are usually 'bool'.
// Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.
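
// Example (a minimal sketch): because each XlaOp remembers its builder, the
// overloads above let ordinary C++ expressions enqueue XLA instructions.
// Assumes the free function Parameter declared later in this header.
//
//   XlaBuilder b("arith");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {8}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {8}), "y");
//   XlaOp z = (x + y) * ~x;  // enqueues Add, Not, and Mul on b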

// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  virtual ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }

  // Swaps the passed op metadata with the ones currently set.
  //
  // Returns the old op metadata.
  OpMetadata SwapOpMetadata(OpMetadata metadata) {
    OpMetadata old_metadata = std::move(metadata_);
    metadata_ = std::move(metadata);
    return old_metadata;
  }

  // Similar to SetOpMetadata, but only sets the metadata for the next op.
  void SetOneShotOpMetadata(OpMetadata metadata) {
    one_shot_metadata_ = std::move(metadata);
  }

  // Clears the HloMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }
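
  // Example (a minimal sketch): tag a run of instructions with metadata and
  // then clear it. op_type and op_name are fields of the OpMetadata proto.
  //
  //   OpMetadata m;
  //   m.set_op_type("Softmax");
  //   m.set_op_name("model/layer0/softmax");
  //   builder.SetOpMetadata(m);
  //   // ... every instruction enqueued here carries `m` ...
  //   builder.ClearOpMetadata();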

  // Sets an OpSharding that will be attached to all instructions until cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Sets the FrontendAttributes that will be added to all instructions until
  // cleared.
  //
  // FrontendAttributes are often applied to a series of XLA HLO instructions.
  // As a result they are set on the computation builder and all the
  // instructions generated via the computation builder will have the same
  // frontend attributes attached to them.
  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
    frontend_attributes_ = frontend_attributes;
  }

  // Swaps the passed FrontendAttributes with the ones currently set.
  //
  // Returns the old attributes.
  FrontendAttributes SwapFrontendAttributes(
      const FrontendAttributes& frontend_attributes) {
    FrontendAttributes old_attributes = std::move(frontend_attributes_);
    frontend_attributes_ = frontend_attributes;
    return old_attributes;
  }

  // Returns the FrontendAttributes that will be attached to all instructions.
  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Clears all the frontend attributes.
  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }

  // Clears the sharding. Ops will be sharded according to the default placement
  // policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }
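
  // Example (a minimal sketch): mark a run of instructions as replicated.
  // OpSharding and its REPLICATED type come from xla_data.pb.h.
  //
  //   OpSharding replicated;
  //   replicated.set_type(OpSharding::REPLICATED);
  //   builder.SetSharding(replicated);
  //   // ... instructions enqueued here carry the sharding ...
  //   builder.ClearSharding();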

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build() is
  // called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64_t kConvBatchDimension = 0;
  static constexpr int64_t kConvFeatureDimension = 1;
  static constexpr int64_t kConvFirstSpatialDimension = 2;
  static constexpr int64_t kConvSecondSpatialDimension = 3;
  static constexpr int64_t kConvKernelOutputDimension = 0;
  static constexpr int64_t kConvKernelInputDimension = 1;
  static constexpr int64_t kConvKernelFirstSpatialDimension = 2;
  static constexpr int64_t kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
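
  // Example (a minimal sketch): build a scalar comparator for an op that
  // takes an XlaComputation. Max and Parameter are free functions declared
  // for XlaOp elsewhere in this header.
  //
  //   std::unique_ptr<XlaBuilder> sub = b.CreateSubBuilder("max");
  //   Max(Parameter(sub.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"),
  //       Parameter(sub.get(), 1, ShapeUtil::MakeShape(F32, {}), "y"));
  //   XlaComputation max_comp = sub->BuildAndNoteError();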

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = false);
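
  // Example (a minimal sketch): the last enqueued op becomes the root unless
  // an explicit root is passed.
  //
  //   XlaBuilder b("add_one");
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   XlaOp sum = p + ConstantLiteral(&b, LiteralUtil::CreateR0<float>(1.0f));
  //   StatusOr<XlaComputation> computation = b.Build(sum);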

  // Builds the computation with the requested operations, or notes an error in
  // the parent XlaBuilder and returns an empty computation if building failed.
  // This function is intended to be used where the returned XlaComputation is
  // only used by the parent XlaBuilder, and hence further operations on the
  // returned XlaComputation will simply be errored out if an error occurred
  // while building this computation. If the built computation is to be used by
  // an XlaBuilder other than the parent XlaBuilder then Build() should be used
  // instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph rooted at the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(XlaOp op) const;

  // Returns a pointer to the shape of the given op, owned by the builder.
  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
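
  // Example (a minimal sketch): query the inferred shape of an enqueued op.
  //
  //   StatusOr<Shape> shape = b.GetShape(op);
  //   if (shape.ok()) {
  //     VLOG(1) << "op shape: " << ShapeUtil::HumanString(shape.ValueOrDie());
  //   }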

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using the
  // given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //    Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value can
  // be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns an
  // invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
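
  // Example (a minimal sketch): the usual pattern in op-building helpers, so
  // that a failed shape lookup becomes a deferred builder error instead of a
  // crash. InvalidArgument is the error helper from xla/util.h.
  //
  //   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  //     TF_ASSIGN_OR_RETURN(const Shape* shape, b->GetShapePtr(operand));
  //     if (!shape->IsArray()) {
  //       return InvalidArgument("operand must be an array");
  //     }
  //     return operand;
  //   });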

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(XlaOp operand) const;
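
  // Example (a minimal sketch): a purely constant expression versus one that
  // depends on a parameter.
  //
  //   XlaOp c = ConstantLiteral(&b, LiteralUtil::CreateR0<float>(2.0f));
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   b.IsConstant(c * c);  // StatusOr<bool> holding true
  //   b.IsConstant(c + p);  // StatusOr<bool> holding false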

  // Sets up binding which indicates that the `target_dim_num` in the subshape
  // `target_param_index` of parameter `target_param_num` is a dynamic dimension
  // and its real dynamic size is represented by `dynamic_param_index` in
  // parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.")
  Status SetDynamicBinding(int64_t dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64_t target_param_num,
                           ShapeIndex target_param_index,
                           int64_t target_dim_num);

  // Adds a new input/output alias. Since the input/output shape information is
  // not available until the computation is built, any eventual error in the
  // arguments of this API will be detected only at computation Build() time.
  //
  // Note: Except when 'must-alias' is true, the alias is assumed to be a
  // 'may-alias', and only a buffer donated at runtime will be aliased with the
  // output. If a buffer is not donated at runtime, a copy will be inserted by
  // XLA to prevent buffer clobbering.
  void SetUpAlias(const ShapeIndex& output_index, int64_t param_number,
                  const ShapeIndex& param_index,
                  HloInputOutputAliasConfig::AliasKind kind =
                      HloInputOutputAliasConfig::AliasKind::kMayAlias) {
    input_output_aliases_.push_back(
        {output_index, param_number, param_index, kind});
  }
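
  // Example (a minimal sketch): let element {0} of the result tuple reuse
  // parameter 0's buffer when it is donated at runtime.
  //
  //   b.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
  //                /*param_index=*/{});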

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64 param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
    // Specifies whether the alias is a must-alias or a may-alias.
    HloInputOutputAliasConfig::AliasKind kind;
  };

  // Looks up the HloInstruction and sets the frontend attribute "attribute" to
  // "value".
  //
  // If the attribute already existed, then its value is updated.
  //
  // Note: the attribute is only added to the HloInstruction, not to the
  // builder.
  Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
                                         string value);

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // Converts the op to string for ease of debugging.
  std::string OpToString(XlaOp op) const;

 private:
  void ToStringHelper(std::string* out, int ident, int64_t op_handle) const;

  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64_t root_id,
                                 bool remove_dynamic_dimensions);

  // Description for the methods below can be found in the corresponding public
  // functions section in this file.

  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const string& name,
                  const std::vector<bool>& replicated_at_leaf_buffers);
  XlaOp Parameter(int64_t parameter_number, const Shape& shape,
                  const string& name) {
    std::vector<bool> empty_bools;
    return Parameter(parameter_number, shape, name, empty_bools);
  }

  virtual XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);

  XlaOp BroadcastInDim(XlaOp operand,
                       const absl::Span<const int64> out_dim_size,
                       const absl::Span<const int64> broadcast_dimensions);

  XlaOp Pad(XlaOp operand, XlaOp padding_value,
            const PaddingConfig& padding_config);
  XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                 int64_t pad_lo, int64_t pad_hi);

  virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                                      XlaOp padding_value,
                                      const PaddingConfig& padding_config);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                int64_t inferred_dimension = -1);

  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64_t inferred_dimension = -1);

  XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                       absl::Span<const int64> new_size_bounds,
                       const std::vector<bool>& dims_are_dynamic);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);
  virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> start_indices,
                                        absl::Span<const int64> limit_indices,
                                        absl::Span<const int64> strides);
  virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                           int64_t limit_index, int64_t stride, int64_t dimno);

  XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64> slice_sizes);
  virtual StatusOr<XlaOp> DynamicSliceInternal(
      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
      absl::Span<const int64> slice_sizes);

  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                           absl::Span<const XlaOp> start_indices);
  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
      const Shape& shape, XlaOp operand, XlaOp update,
      absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64_t dimension);
  virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
                                              absl::Span<const XlaOp> operands,
                                              int64_t dimension);

  void Trace(const string& tag, XlaOp operand);

  XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);
  virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
                                        absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
                                                  XlaOp tuple_data,
                                                  int64_t index);

  XlaOp Dot(
      XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DotGeneral(
      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp Conv(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, int64_t feature_group_count = 1,
      int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count = 1, int64_t batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  StatusOr<HloInstructionProto> DynamicConvInstruction(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config);

  XlaOp Fft(XlaOp operand, FftType fft_type,
            absl::Span<const int64> fft_length);
  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                                      FftType fft_type,
                                      absl::Span<const int64> fft_length);

  virtual StatusOr<XlaOp> TriangularSolveInternal(
      const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options);

  virtual StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
                                           bool lower);

  XlaOp Infeed(const Shape& shape, const string& config = "");
  XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
  virtual StatusOr<XlaOp> InfeedWithTokenInternal(
      const Shape& infeed_instruction_shape, XlaOp token, const string& config);

  void Outfeed(XlaOp operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);
  virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
      const string& outfeed_config);
  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, absl::optional<Window> window,
      absl::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  // Internal version of CustomCall without computation that doesn't do
  // op-specific error handling and expects arguments to be legal. The
  // CustomCall method above calls this method after error handling.
  virtual StatusOr<XlaOp> CustomCallInternal(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, absl::optional<Window> window,
      absl::optional<ConvolutionDimensionNumbers> dnums,
      CustomCallSchedule schedule, CustomCallApiVersion api_version);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation& computation, const Shape& shape_with_layout,
      const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal, CustomCallSchedule schedule,
      CustomCallApiVersion api_version);

  XlaOp Reduce(XlaOp operand, XlaOp init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  virtual StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64> dimensions_to_reduce);

  XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                     absl::Span<const XlaOp> init_values,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  StatusOr<HloInstructionProto> ReduceWindowInternal(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  virtual StatusOr<XlaOp> ReduceWindowInternal(
      const Shape& shape, XlaOp operand, XlaOp init_value,
      const XlaComputation& computation, Window window);
  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp AllGather(
      XlaOp operand, int64_t all_gather_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Layout>& layout = absl::nullopt,
      const absl::optional<bool> use_global_device_ids = absl::nullopt);

  XlaOp AllReduce(
      XlaOp operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Shape>& shape_with_layout = absl::nullopt);

  XlaOp ReduceScatter(
      XlaOp operand, const XlaComputation& computation,
      int64_t scatter_dimension, int64_t shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Layout>& layout = absl::nullopt,
      const absl::optional<bool> use_global_device_ids = absl::nullopt);

  XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
                 int64_t concat_dimension, int64_t split_count,
                 absl::Span<const ReplicaGroup> replica_groups,
                 const absl::optional<Layout>& layout = absl::nullopt);

  XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
                      int64_t concat_dimension, int64_t split_count,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const absl::optional<Layout>& layout);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, XlaOp source, XlaOp init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  StatusOr<HloInstructionProto> SelectAndScatterInternal(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension);

  XlaOp Iota(PrimitiveType type, int64_t size);

  XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
                                                     XlaOp operand);

  XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
  virtual StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);

  XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                                      absl::Span<const int64> dimensions);

  XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
             int64_t dimension = -1, bool is_stable = false);
  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
                                       absl::Span<const XlaOp> operands,
                                       const XlaComputation& comparator,
                                       int64_t dimension, bool is_stable);

  XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

  XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

  XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                        const Shape& shape);
  // Internal variant for the op with the full result shape containing both data
  // and state shape as a tuple.
  virtual StatusOr<XlaOp> RngBitGeneratorInternal(
      const Shape& full_result_shape, RandomAlgorithm algorithm,
      XlaOp initial_state);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              XlaOp init);
  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                        const XlaComputation& condition,
                                        const XlaComputation& body, XlaOp init);

  XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                    const XlaComputation& true_computation, XlaOp false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(XlaOp branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                        const int mantissa_bits);
  virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
                                                  XlaOp operand,
                                                  const int exponent_bits,
                                                  const int mantissa_bits);

  XlaOp Gather(XlaOp input, XlaOp start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes,
               bool indices_are_sorted = false);

  virtual StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);

  virtual StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
      const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
      bool unique_indices);

  void Send(XlaOp operand, const ChannelHandle& handle);
  XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

  XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                   const ChannelHandle& handle);

  XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                     const ChannelHandle& handle);

  virtual XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                          float epsilon, int64_t feature_index);

  XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                           XlaOp variance, float epsilon,
                           int64_t feature_index);

  XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                      XlaOp batch_var, XlaOp grad_output, float epsilon,
                      int64_t feature_index);

  XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

  XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

  virtual StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape,
                                                   XlaOp operand, XlaOp val,
                                                   int64_t dimension);

  XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);

  virtual StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                         HloOpcode opcode,
                                         absl::Span<const XlaOp> operands);
  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                 HloOpcode opcode) {
    return AddInstruction(std::move(instr), opcode, /*operands=*/{});
  }

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64_t handle) const;
  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(
      int64_t handle);

  // Internal helper method that does the building for an arbitrary unary op.
  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction is
  // only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                 absl::Span<const int64> broadcast_dimensions,
                 absl::optional<ComparisonDirection> direction = absl::nullopt,
                 absl::optional<Comparison::Type> type = absl::nullopt);

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction);

  // Internal helper method for binary op compare without broadcast dimensions.
  virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                  ComparisonDirection direction,
                                  Comparison::Type type);

  // Internal helper method that does the building for an arbitrary binary op
  // with same ranked operands that doesn't broadcast.
  virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                    XlaOp lhs, XlaOp rhs);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  virtual StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                        absl::Span<const XlaOp> parameters,
                                        const Shape& shape);

  virtual StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       XlaOp operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  virtual StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                          int64_t inferred_dimension);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64_t root_id) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false iff
  // a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64_t op_handle, int depth,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  int64 GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored within
  // the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
  std::vector<std::unique_ptr<Shape>> instruction_shapes_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias() API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where the
  // instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // Track imported instructions by their computation id and the position in
  // their computation's instruction list.
  struct ImportedInstruction {
    int64 computation_id;
    int64 instruction_index;
  };

  absl::flat_hash_map<int64, ImportedInstruction> handle_to_imported_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // A temporary metadata that will only be applied to the next op created.
  absl::optional<OpMetadata> one_shot_metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  FrontendAttributes frontend_attributes_;

  friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                         const Shape& shape, const string& name,
                         const std::vector<bool>& replicated_at_leaf_buffers);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(XlaOp operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      XlaOp operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Copy(XlaOp operand);

  friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                        int64_t pad_lo, int64_t pad_hi);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                              absl::Span<const int64> new_size_bounds,
                              const std::vector<bool>& dims_are_dynamic);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64> new_sizes,
                                            int64_t inferred_dimension);

  friend XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  friend XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(XlaOp operand, int64_t start_index,
                          int64_t limit_index, int64_t stride, int64_t dimno);

  friend XlaOp DynamicSlice(XlaOp operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64_t dimension);

  friend void Trace(const string& tag, XlaOp operand);

  friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction,
                       Comparison::Type compare_type);
  friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
                   const PrecisionConfig* precision_config,
                   absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config,
                          absl::optional<PrimitiveType> preferred_element_type);
  virtual StatusOr<XlaOp> DotGeneralInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs,
      const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config);
  friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64_t feature_group_count, int64_t batch_group_count,
                    const PrecisionConfig* precision_config,
                    absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvKernelGrad(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64_t feature_group_count, int64_t batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
1210   friend XlaOp Fft(XlaOp operand, FftType fft_type,
1211                    absl::Span<const int64> fft_length);
1212   friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
1213                                bool unit_diagonal,
1214                                TriangularSolveOptions::Transpose transpose_a);
1215   friend XlaOp Cholesky(XlaOp a, bool lower);
1216   friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
1217                       const string& config);
1218   friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
1219                       const string& outfeed_config);
1220   friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
1221                     absl::Span<const XlaOp> operands);
1222   friend XlaOp CustomCall(
1223       XlaBuilder* builder, const string& call_target_name,
1224       absl::Span<const XlaOp> operands, const Shape& shape,
1225       const string& opaque, bool has_side_effect,
1226       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1227           output_operand_aliasing,
1228       const Literal* literal, CustomCallSchedule schedule,
1229       CustomCallApiVersion api_version);
1230   friend XlaOp CustomCallWithComputation(
1231       XlaBuilder* builder, const string& call_target_name,
1232       absl::Span<const XlaOp> operands, const XlaComputation& computation,
1233       const Shape& shape, const string& opaque, bool has_side_effect,
1234       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1235           output_operand_aliasing,
1236       const Literal* literal, CustomCallSchedule schedule,
1237       CustomCallApiVersion api_version);
1238   friend XlaOp CustomCallWithLayout(
1239       XlaBuilder* builder, const string& call_target_name,
1240       absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
1241       absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
1242       bool has_side_effect,
1243       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1244           output_operand_aliasing,
1245       const Literal* literal, CustomCallSchedule schedule,
1246       CustomCallApiVersion api_version);
1247   friend XlaOp CustomCallWithConvDnums(
1248       XlaBuilder* builder, const string& call_target_name,
1249       absl::Span<const XlaOp> operands, const Shape& shape,
1250       absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
1251       bool has_side_effect,
1252       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1253           output_operand_aliasing,
1254       const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
1255       CustomCallSchedule schedule, CustomCallApiVersion api_version);
1256   friend XlaOp Complex(XlaOp real, XlaOp imag,
1257                        absl::Span<const int64> broadcast_dimensions);
1258   friend XlaOp Conj(XlaOp operand);
1259   friend XlaOp Add(XlaOp lhs, XlaOp rhs,
1260                    absl::Span<const int64> broadcast_dimensions);
1261   friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
1262                    absl::Span<const int64> broadcast_dimensions);
1263   friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
1264                    absl::Span<const int64> broadcast_dimensions);
1265   friend XlaOp Div(XlaOp lhs, XlaOp rhs,
1266                    absl::Span<const int64> broadcast_dimensions);
1267   friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
1268                    absl::Span<const int64> broadcast_dimensions);
1269   friend XlaOp Max(XlaOp lhs, XlaOp rhs,
1270                    absl::Span<const int64> broadcast_dimensions);
1271   friend XlaOp Min(XlaOp lhs, XlaOp rhs,
1272                    absl::Span<const int64> broadcast_dimensions);
1273   friend XlaOp And(XlaOp lhs, XlaOp rhs,
1274                    absl::Span<const int64> broadcast_dimensions);
1275   friend XlaOp Or(XlaOp lhs, XlaOp rhs,
1276                   absl::Span<const int64> broadcast_dimensions);
1277   friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
1278                    absl::Span<const int64> broadcast_dimensions);
1279   friend XlaOp Not(XlaOp operand);
1280   friend XlaOp PopulationCount(XlaOp operand);
1281   friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
1282                          absl::Span<const int64> broadcast_dimensions);
1283   friend XlaOp ShiftRightArithmetic(
1284       XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
1285   friend XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
1286                                  absl::Span<const int64> broadcast_dimensions);
1287   friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
1288                       const XlaComputation& computation,
1289                       absl::Span<const int64> dimensions_to_reduce);
1290   friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1291                       absl::Span<const XlaOp> init_values,
1292                       const XlaComputation& computation,
1293                       absl::Span<const int64> dimensions_to_reduce);
1294   friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
1295                          const XlaComputation& computation);
1296   friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
1297                             const XlaComputation& computation,
1298                             absl::Span<const int64> window_dimensions,
1299                             absl::Span<const int64> window_strides,
1300                             Padding padding);
1301   friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
1302                             absl::Span<const XlaOp> init_values,
1303                             const XlaComputation& computation,
1304                             absl::Span<const int64> window_dimensions,
1305                             absl::Span<const int64> window_strides,
1306                             Padding padding);
1307   friend XlaOp ReduceWindowWithGeneralPadding(
1308       XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1309       absl::Span<const int64> window_dimensions,
1310       absl::Span<const int64> window_strides,
1311       absl::Span<const int64> base_dilations,
1312       absl::Span<const int64> window_dilations,
1313       absl::Span<const std::pair<int64, int64>> padding);
1314   friend XlaOp ReduceWindowWithGeneralPadding(
1315       absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
1316       const XlaComputation& computation,
1317       absl::Span<const int64> window_dimensions,
1318       absl::Span<const int64> window_strides,
1319       absl::Span<const int64> base_dilations,
1320       absl::Span<const int64> window_dilations,
1321       absl::Span<const std::pair<int64, int64>> padding);
1322 
1323   friend XlaOp CrossReplicaSum(XlaOp operand,
1324                                absl::Span<const ReplicaGroup> replica_groups);
1325   friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
1326                          int64_t shard_count,
1327                          absl::Span<const ReplicaGroup> replica_groups,
1328                          const absl::optional<ChannelHandle>& channel_id,
1329                          const absl::optional<Layout>& layout,
1330                          const absl::optional<bool> use_global_device_ids);
1331   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
1332                          absl::Span<const ReplicaGroup> replica_groups,
1333                          const absl::optional<ChannelHandle>& channel_id,
1334                          const absl::optional<Shape>& shape_with_layout);
1335   friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
1336                              int64_t scatter_dimension, int64_t shard_count,
1337                              absl::Span<const ReplicaGroup> replica_groups,
1338                              const absl::optional<ChannelHandle>& channel_id,
1339                              const absl::optional<Layout>& layout,
1340                              const absl::optional<bool> use_global_device_ids);
1341 
1342   friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
1343                         int64_t concat_dimension, int64_t split_count,
1344                         absl::Span<const ReplicaGroup> replica_groups,
1345                         const absl::optional<Layout>& layout);
1346   friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
1347                              int64_t concat_dimension, int64_t split_count,
1348                              absl::Span<const ReplicaGroup> replica_groups,
1349                              const absl::optional<Layout>& layout);
1350   friend XlaOp CollectivePermute(
1351       XlaOp operand,
1352       const std::vector<std::pair<int64, int64>>& source_target_pairs);
1353   friend XlaOp ReplicaId(XlaBuilder* builder);
1354   friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
1355                                 absl::Span<const int64> window_dimensions,
1356                                 absl::Span<const int64> window_strides,
1357                                 Padding padding, XlaOp source, XlaOp init_value,
1358                                 const XlaComputation& scatter);
1359   friend XlaOp SelectAndScatterWithGeneralPadding(
1360       XlaOp operand, const XlaComputation& select,
1361       absl::Span<const int64> window_dimensions,
1362       absl::Span<const int64> window_strides,
1363       absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
1364       XlaOp init_value, const XlaComputation& scatter);
1365   friend XlaOp Abs(XlaOp operand);
1366   friend XlaOp Atan2(XlaOp y, XlaOp x,
1367                      absl::Span<const int64> broadcast_dimensions);
1368   friend XlaOp Exp(XlaOp operand);
1369   friend XlaOp Expm1(XlaOp operand);
1370   friend XlaOp Floor(XlaOp operand);
1371   friend XlaOp Ceil(XlaOp operand);
1372   friend XlaOp Round(XlaOp operand);
1373   friend XlaOp Log(XlaOp operand);
1374   friend XlaOp Log1p(XlaOp operand);
1375   friend XlaOp Logistic(XlaOp operand);
1376   friend XlaOp Sign(XlaOp operand);
1377   friend XlaOp Clz(XlaOp operand);
1378   friend XlaOp Cos(XlaOp operand);
1379   friend XlaOp Sin(XlaOp operand);
1380   friend XlaOp Tanh(XlaOp operand);
1381   friend XlaOp Real(XlaOp operand);
1382   friend XlaOp Imag(XlaOp operand);
1383   friend XlaOp Sqrt(XlaOp operand);
1384   friend XlaOp Rsqrt(XlaOp operand);
1385   friend XlaOp Cbrt(XlaOp operand);
1386   friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
1387                    absl::Span<const int64> broadcast_dimensions);
1388   friend XlaOp IsFinite(XlaOp operand);
1389   friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
1390                     int64_t iota_dimension);
1391   friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);
1392   friend XlaOp ConvertElementType(XlaOp operand,
1393                                   PrimitiveType new_element_type);
1394   friend XlaOp BitcastConvertType(XlaOp operand,
1395                                   PrimitiveType new_element_type);
1396   friend XlaOp Neg(XlaOp operand);
1397   friend XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
1398   friend XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
1399   friend XlaOp Sort(absl::Span<const XlaOp> operands,
1400                     const XlaComputation& comparator, int64_t dimension,
1401                     bool is_stable);
1402   friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
1403   friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1404                    const XlaComputation& computation,
1405                    absl::Span<const int64> dimensions,
1406                    absl::Span<const XlaOp> static_operands);
1407   friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
1408   friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
1409   friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
1410                                const Shape& shape);
1411   friend XlaOp While(const XlaComputation& condition,
1412                      const XlaComputation& body, XlaOp init);
1413   friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
1414                            const XlaComputation& true_computation,
1415                            XlaOp false_operand,
1416                            const XlaComputation& false_computation);
1417   friend XlaOp Conditional(
1418       XlaOp branch_index,
1419       absl::Span<const XlaComputation* const> branch_computations,
1420       absl::Span<const XlaOp> branch_operands);
1421   friend XlaOp ConditionalImpl(
1422       XlaOp branch_index,
1423       absl::Span<const XlaComputation* const> branch_computations,
1424       absl::Span<const XlaOp> branch_operands);
1425   friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
1426                                const int mantissa_bits);
1427   friend XlaOp Gather(XlaOp input, XlaOp start_indices,
1428                       const GatherDimensionNumbers& dimension_numbers,
1429                       absl::Span<const int64> slice_sizes,
1430                       bool indices_are_sorted);
1431   friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1432                        const XlaComputation& update_computation,
1433                        const ScatterDimensionNumbers& dimension_numbers,
1434                        bool indices_are_sorted, bool unique_indices);
1435   friend void Send(XlaOp operand, const ChannelHandle& handle);
1436   friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1437                     const ChannelHandle& handle);
1438   friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
1439                                  float epsilon, int64_t feature_index);
1440   friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
1441                                   XlaOp mean, XlaOp variance, float epsilon,
1442                                   int64_t feature_index);
1443   friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
1444                              XlaOp batch_var, XlaOp grad_output, float epsilon,
1445                              int64_t feature_index);
1446   friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
1447                              const ChannelHandle& handle);
1448   friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
1449                              const ChannelHandle& handle);
1450   friend XlaOp SendToHost(XlaOp operand, XlaOp token,
1451                           const Shape& shape_with_layout,
1452                           const ChannelHandle& handle);
1453   friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
1454                             const ChannelHandle& handle);
1455   friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1456                                const string& config);
1457   friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1458                                 const Shape& shape_with_layout,
1459                                 const string& outfeed_config);
1460   friend XlaOp CreateToken(XlaBuilder* builder);
1461   friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1462 
1463   friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);
1464   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);
1465   friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);
1466 
1467  protected:
1468   // Returns OK status if the given op was built using this builder. Otherwise,
1469   // returns an error.
1470   Status CheckOpBuilder(XlaOp op) const;
1471 
1472  private:
1473   XlaOp ConditionalImpl(
1474       XlaOp branch_index,
1475       absl::Span<const XlaComputation* const> branch_computations,
1476       absl::Span<const XlaOp> branch_operands);
1477 
1478   XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
1479                       int64_t concat_dimension, int64_t split_count,
1480                       absl::Span<const ReplicaGroup> replica_groups);
1481 
1482   // Creates an op with the given opcode and the output shape.
1483   virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
1484                                          absl::Span<const XlaOp> operands);
1485 
1486   // Here, InstructionType is either const HloInstructionProto* or non-const
1487   // HloInstructionProto*.
1488   template <typename InstructionType>
1489   StatusOr<InstructionType> LookUpInstructionByHandleInternal(
1490       int64_t handle) const {
1491     auto it = handle_to_index_.find(handle);
1492     if (it == handle_to_index_.end()) {
1493       // Try to find the instruction among the imported instructions.
1494       auto imported_it = handle_to_imported_index_.find(handle);
1495       if (imported_it != handle_to_imported_index_.end()) {
1496         ImportedInstruction imported = imported_it->second;
1497         return const_cast<InstructionType>(
1498             &embedded_.at(imported.computation_id)
1499                  .instructions(imported.instruction_index));
1500       }
1501       return InvalidArgument("No XlaOp with handle %d", handle);
1502     }
1503     return const_cast<InstructionType>(&instructions_.at(it->second));
1504   }
1505 
1506   // Here, InstructionType is either const HloInstructionProto* or non-const
1507   // HloInstructionProto*.
1508   //
1509   // TODO(hinsu): Return const pointer within StatusOr and use
1510   // absl::implicit_cast at callsites. This requires implicit_cast support in
1511   // stream_executor::port::StatusOr similar to absl::StatusOr.
1512   template <typename InstructionType>
1513   StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
1514     TF_RETURN_IF_ERROR(CheckOpBuilder(op));
1515     return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
1516   }
1517 
1518   friend struct internal::XlaBuilderFriend;
1519 
1520   friend class ValueInference;
1521 };
1522 
1523 // RAII-style object: sets the current sharding assignment in builder on
1524 // construction, and sets back to the previous assignment on destruction.
1525 class XlaScopedShardingAssignment {
1526  public:
1527   XlaScopedShardingAssignment(xla::XlaBuilder* builder,
1528                               absl::optional<OpSharding> sharding)
1529       : builder_(builder), prev_sharding_(builder->sharding()) {
1530     SetSharding(sharding);
1531   }
1532 
1533   XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
1534   XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
1535       delete;
1536 
1537   ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
1538 
1539  private:
1540   void SetSharding(const absl::optional<OpSharding>& sharding) {
1541     if (sharding.has_value()) {
1542       builder_->SetSharding(sharding.value());
1543     } else {
1544       builder_->ClearSharding();
1545     }
1546   }
1547 
1548   xla::XlaBuilder* const builder_;
1549   absl::optional<OpSharding> prev_sharding_;
1550 };
1551 
1552 // RAII-style object: saves the builder's current frontend attributes on
1553 // construction and swaps in the given ones; restores the original
1554 // attributes on destruction.
1555 class XlaScopedFrontendAttributesAssignment {
1556  public:
1557   XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
1558                                         FrontendAttributes attributes)
1559       : builder_(builder) {
1560     saved_ = builder_->SwapFrontendAttributes(attributes);
1561   }
1562 
1563   ~XlaScopedFrontendAttributesAssignment() {
1564     builder_->SetFrontendAttributes(saved_);
1565   }
1566 
1567  private:
1568   xla::XlaBuilder* const builder_;
1569   FrontendAttributes saved_;
1570 
1571   TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
1572 };
1573 
1574 // RAII-style object: sets the current op metadata in builder on construction,
1575 // and sets back to the previous assignment on destruction.
1576 class XlaScopedOpMetadataAssignment {
1577  public:
1578   XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata)
1579       : builder_(builder) {
1580     saved_ = builder_->SwapOpMetadata(metadata);
1581   }
1582 
1583   ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); }
1584 
1585  private:
1586   xla::XlaBuilder* const builder_;
1587   OpMetadata saved_;
1588 
1589   TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment);
1590 };
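
// Example (illustrative sketch, not part of this header): using the scoped
// RAII helpers above to tag ops built within a region. The builder name and
// metadata values are hypothetical.
//
//   XlaBuilder b("scoped_example");
//   OpMetadata m;
//   m.set_op_name("my_layer");
//   {
//     XlaScopedOpMetadataAssignment scoped(&b, m);
//     XlaOp x = ConstantR0<float>(&b, 1.0f);  // Tagged with metadata `m`.
//   }
//   // Ops created here revert to the previous metadata.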
1591 
1592 // Free functions for building XlaOps. The intention is that these will
1593 // become the public API for building XlaOps rather than calling methods on
1594 // XlaBuilder directly.
1595 //
1596 
1597 // Enqueues a "retrieve parameter value" instruction for a parameter that was
1598 // passed to the computation.
1599 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
1600                 const Shape& shape, const string& name);
1601 
1602 // Same as above, but with leaf buffer replication annotation.
1603 XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
1604                 const Shape& shape, const string& name,
1605                 const std::vector<bool>& replicated_at_leaf_buffers);
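
// Example (illustrative sketch): declaring a computation input on a
// hypothetical builder `b`.
//
//   XlaOp arg0 = Parameter(&b, /*parameter_number=*/0,
//                          ShapeUtil::MakeShape(F32, {2, 2}), "arg0");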
1606 
1607 // Enqueues a constant with the value of the given literal onto the
1608 // computation.
1609 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
1610 
1611 // Enqueues a constant onto the computation. Methods are templated on the
1612 // native host type (NativeT) which corresponds to a specific XLA
1613 // PrimitiveType as given in the following table:
1614 //
1615 //  Native Type   PrimitiveType
1616 // -----------------------------
1617 //   bool           PRED
1618 //   int32          S32
1619 //   int64          S64
1620 //   uint32         U32
1621 //   uint64         U64
1622 //   float          F32
1623 //   double         F64
1624 //
1625 // Note: not all primitive types defined in xla_data.proto have a
1626 // corresponding native type yet.
1627 template <typename NativeT>
1628 XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
1629 template <typename NativeT>
1630 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
1631 XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
1632 template <typename NativeT>
1633 XlaOp ConstantR2(XlaBuilder* builder,
1634                  std::initializer_list<std::initializer_list<NativeT>> values);
1635 template <typename NativeT>
1636 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
1637                                   const Array<NativeT>& values,
1638                                   const Layout& layout);
1639 template <typename NativeT>
1640 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
1641 template <typename NativeT>
1642 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
1643                                       const Array2D<NativeT>& values,
1644                                       const Layout& layout);
1645 template <typename NativeT>
1646 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
1647                             const Array2D<NativeT>& values);
1648 template <typename NativeT>
1649 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
1650                                       const Array3D<NativeT>& values,
1651                                       const Layout& layout);
1652 template <typename NativeT>
1653 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
1654                             const Array3D<NativeT>& values);
1655 template <typename NativeT>
1656 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
1657                                       const Array4D<NativeT>& values,
1658                                       const Layout& layout);
1659 template <typename NativeT>
1660 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
1661                             const Array4D<NativeT>& values);
1662 
1663 // Enqueues a rank-1 constant vector onto the computation. The
1664 // vector has size 'length' and every element has the value
1665 // 'value'.
1666 template <typename NativeT>
1667 XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value);
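
// Example (illustrative sketch): enqueueing a few constants on a hypothetical
// builder `b`.
//
//   XlaBuilder b("constants_example");
//   XlaOp scalar = ConstantR0<float>(&b, 42.0f);          // f32[]
//   XlaOp vec = ConstantR1<float>(&b, {1.0f, 2.0f});      // f32[2]
//   XlaOp filled = ConstantR1<float>(&b, /*length=*/4, /*value=*/0.0f);
//   XlaOp mat = ConstantR2<int32>(&b, {{1, 2}, {3, 4}});  // s32[2,2]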
1668 
1669 // Adds dimensions to an array by duplicating the data in the array.
1670 //
1671 // The new dimensions are inserted on the left, i.e. if
1672 // broadcast_sizes has values {a0, ..., aN} and the operand shape
1673 // has dimensions {b0, ..., bM} then the shape of the output has
1674 // dimensions {a0, ..., aN, b0, ..., bM}.
1675 //
1676 // The new dimensions index into copies of the operand, i.e.
1677 //
1678 //   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
1679 XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);
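
// Example (illustrative sketch): prepending a new major dimension, assuming a
// hypothetical XlaOp `vec` of shape f32[3].
//
//   XlaOp tiled = Broadcast(vec, /*broadcast_sizes=*/{2});  // f32[2,3]
//   // tiled[i, j] == vec[j] for i in {0, 1}.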
1680 
1681 // This op broadcasts the `operand` to an output with the given `shape`.
1682 // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the
1683 // i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
1684 // dimension of the output. This also requires that the i'th input dimension is
1685 // either 1 or is the same as the output dimension it's broadcasting into.
1686 //
1687 // For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
1688 // output shape is s32[2,2]:
1689 // - Specifying {1} as broadcast_dimension will generate output
1690 //   {{1, 2},
1691 //    {1, 2}}
1692 // - On the other hand, specifying {0} as broadcast_dimension
1693 //   will generate output
1694 //   {{1 , 1},
1695 //    {2 , 2}}
1696 XlaOp BroadcastInDim(XlaOp operand, const absl::Span<const int64> out_dim_size,
1697                      const absl::Span<const int64> broadcast_dimensions);
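
// Example (illustrative sketch): the two mappings described above, assuming a
// hypothetical XlaOp `v` with value {1, 2} of shape s32[2].
//
//   // Dim 0 of `v` maps to output dim 1: rows are copies of `v`.
//   XlaOp rows = BroadcastInDim(v, /*out_dim_size=*/{2, 2},
//                               /*broadcast_dimensions=*/{1});
//   // Dim 0 of `v` maps to output dim 0: columns are copies of `v`.
//   XlaOp cols = BroadcastInDim(v, /*out_dim_size=*/{2, 2},
//                               /*broadcast_dimensions=*/{0});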
1698 
1699 // Copies the input operand to the output. This operation is for internal
1700 // purposes and is only used by the compiler for optimization or to
1701 // ensure correctness. The XLA client should never have to generate this
1702 // instruction.
1703 //
1704 // Copy has two potential use cases:
1705 //
1706 // * Create a copy of the operand with a new layout.
1707 //
1708 // * Create a copy of the operand in a separately allocated buffer. This is
1709 //   necessary for some backends if the operand is a parameter or constant and
1710 //   the operand is returned within a tuple. In this case, the lifetime of the
1711 //   operand buffer must be the same as the lifetime of the output result.
1712 //   However, the lifetimes of parameters and constants are managed separately
1713 //   from the lifetime of the output result. Creating a separate copy of the
1714 //   parameter or constant buffer resolves this issue.
1715 XlaOp Copy(XlaOp operand);
1716 
1717 // Enqueues a pad operation onto the computation that pads the given value on
1718 // the edges as well as between the elements of the input. padding_config
1719 // specifies the padding amount for each dimension.
1720 XlaOp Pad(XlaOp operand, XlaOp padding_value,
1721           const PaddingConfig& padding_config);
1722 
1723 // Enqueues a pad operation in a given dimension, taking all other
1724 // dimensions as they are.
1725 XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
1726                int64_t pad_lo, int64_t pad_hi);
1727 
1728 // Enqueues an operation onto the computation that flattens the operand based
1729 // on the dimension order (major/slowest-varying to minor/fastest-varying)
1730 // given, followed by reshaping it into the shape with the given dimension
1731 // sizes (also major to minor). Conceptually, this is a limited form of
1732 // "shape casting".
1733 XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
1734               absl::Span<const int64> new_sizes);
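
// Example (illustrative sketch): flattening a 2x3 array in transposed order,
// assuming a hypothetical XlaOp `m` of shape f32[2,3].
//
//   // Visit dimension 1 as major, then dimension 0, then reshape to rank 1.
//   XlaOp flat = Reshape(m, /*dimensions=*/{1, 0}, /*new_sizes=*/{6});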
1735 
1736 // Enqueues a dynamic reshape operation. The dynamic reshape takes additional
1737 // XlaOps as sizes for the result dimensions. Result dimension i is dynamic
1738 // if dims_are_dynamic[i] is true.
1739 XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
1740                      absl::Span<const int64> new_size_bounds,
1741                      const std::vector<bool>& dims_are_dynamic);
1742 
1743 // Enqueues an operation onto the computation that collapses the operand,
1744 // from first to last dimension (C order), then reshapes it to the given
1745 // dimension sizes. Conceptually, this is a limited form of "shape casting".
1746 XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
1747 
1748 // Enqueues a Reshape op that uses an explicit target shape.
1749 XlaOp Reshape(const Shape& shape, XlaOp operand);
1750 
1751 // `inferred_dimension` represents the output dimension that's inferred by the
1752 // upper-level framework by dividing the input element count by the known
1753 // output element count. While an inferred_dimension can be static, if there
1754 // is a dynamic dimension in the output, it must be the inferred dimension.
1755 XlaOp ReshapeWithInferredDimension(XlaOp operand,
1756                                    absl::Span<const int64> new_sizes,
1757                                    int64_t inferred_dimension);
1758 
1759 // Wrapper for Reshape.
1760 // Enqueues an operation to collapse the provided dimensions; e.g. an
1761 // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
1762 // {x=1024, y=32} by collapsing dims {0, 1, 2}. The dimensions to collapse
1763 // must form a consecutive, in-order subsequence of the operand dimensions.
1764 //
1765 // Note that collapsing a single dimension does nothing:
1766 //
1767 //    {256} collapsing {0} => {256}
1768 //    {1} collapsing {0} => {1}
1769 //
1770 // Collapsing multiple dimensions produces a single result dimension:
1771 //
1772 //    {256, 2} collapsing {0,1} => {512}
1773 //    {256, 2, 3} collapsing {0,1} => {512, 3}
1774 //
1775 // This could potentially cause data to be moved -- it provides a more
1776 // structured form of reshaping than an arbitrary Reshape operation.
1777 XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
1778 
1779 // Enqueues a slice operation onto the computation that slices the operand
1780 // from the start indices to the limit indices; e.g.
1781 //
1782 //        x
1783 //   [ 0 1 2 3 ]
1784 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
1785 //   [ 8 9 a b ]
1786 //
1787 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
1788 // range notation.
1789 // The strides parameter determines the stride over the slice in each dimension.
1790 XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
1791             absl::Span<const int64> limit_indices,
1792             absl::Span<const int64> strides);
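
// Example (illustrative sketch): the slice pictured above, assuming a
// hypothetical XlaOp `x` of shape s32[3,4].
//
//   XlaOp s = Slice(x, /*start_indices=*/{1, 1},
//                   /*limit_indices=*/{2, 3},
//                   /*strides=*/{1, 1});  // s32[1,2] containing {{5, 6}}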
1793 
1794 // Enqueues a slice operation in a given dimension, taking all other
1795 // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
1796 // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
1797 // for:
1798 //
1799 //  array[:, 2:4:1, :]
1800 XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index,
1801                  int64_t stride, int64_t dimno);
1802 
1803 // Enqueues a slice operation onto the computation that slices the 'operand'
1804 // from dynamic start indices which are passed in 'start_indices'.
1805 // The size of the slice in each dimension is passed in 'slice_sizes',
1806 // which specify the end point of exclusive slice intervals in each
1807 // dimension [start, start + size).
1808 // The shape of each element of 'start_indices' must be scalar, with the span
1809 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1810 // have the same shape.
1811 // Slice index calculations are computed modulo input dimension sizes to
1812 // prevent dynamic start indices from generating out-of-bound array accesses.
1813 XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
1814                    absl::Span<const int64> slice_sizes);
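
// Example (illustrative sketch): slicing a 1x2 window at a runtime offset,
// assuming a hypothetical XlaOp `x` of shape f32[3,4] and scalar s32 ops `i`
// and `j`.
//
//   XlaOp window = DynamicSlice(x, /*start_indices=*/{i, j},
//                               /*slice_sizes=*/{1, 2});  // f32[1,2]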
1815 
1816 // Enqueues a dynamic update slice operation onto the computation, which
1817 // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
1818 // The shape of 'update' determines the shape of the slice of 'operand'
1819 // which is updated.
1820 // The indices specified in 'start_indices' specify the offset of the slice
1821 // of 'operand' which is updated.
1822 //
1823 //               update = {10, 11} // calculated at runtime.
1824 //   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
1825 //   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
1826 //   [7 8 9]                                                  [7 8  9 ]
1827 //
1828 // The shape of each element of 'start_indices' must be scalar, with the span
1829 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1830 // have the same shape.
1831 // Slice index calculations are computed modulo update dimension sizes to
1832 // prevent dynamic start indices from generating out-of-bound array accesses.
1833 XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
1834                          absl::Span<const XlaOp> start_indices);
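
// Example (illustrative sketch): the update pictured above, assuming
// hypothetical XlaOps `data` (s32[3,3]), `update` (s32[1,2]) and scalar s32
// ops `i` and `j`.
//
//   XlaOp result = DynamicUpdateSlice(data, update, /*start_indices=*/{i, j});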
1835 
1836 // Enqueues a concatenate instruction onto the computation. 'operands' must
1837 // have >= 1 entry.
1838 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1839                   int64_t dimension);
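
// Example (illustrative sketch): concatenating along dimension 0, assuming
// hypothetical XlaOps `a` (f32[2,3]) and `c` (f32[4,3]) on builder `b`.
//
//   XlaOp cat = ConcatInDim(&b, {a, c}, /*dimension=*/0);  // f32[6,3]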
1840 
1841 // Enqueues a tracing operation onto the computation; the computation will emit
1842 // a logging message with the operand.
1843 void Trace(const string& tag, XlaOp operand);
1844 
1845 // Enqueues a conditional-move-like select operation onto the computation;
1846 // predicated on pred, selects between on_true and on_false.
1847 XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
1848 
1849 // Enqueues a tuple-creation instruction onto the computation.
1850 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
1851 
1852 // Enqueues a tuple-element-get instruction onto the computation.
1853 XlaOp GetTupleElement(XlaOp tuple_data, int64_t index);
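
// Example (illustrative sketch): packing and unpacking a tuple, assuming
// hypothetical XlaOps `x` and `y` on builder `b`.
//
//   XlaOp t = Tuple(&b, {x, y});          // (shape(x), shape(y))
//   XlaOp first = GetTupleElement(t, 0);  // same shape and value as `x`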
1854 
1855 // Enqueues an equal-to comparison instruction onto the computation.
1856 XlaOp Eq(XlaOp lhs, XlaOp rhs,
1857          absl::Span<const int64> broadcast_dimensions = {});
1858 XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
1859                    absl::Span<const int64> broadcast_dimensions = {});
1860 
1861 // Enqueues a not-equal comparison instruction onto the computation.
1862 XlaOp Ne(XlaOp lhs, XlaOp rhs,
1863          absl::Span<const int64> broadcast_dimensions = {});
1864 XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
1865                    absl::Span<const int64> broadcast_dimensions = {});
1866 
1867 // Enqueues a greater-or-equal comparison instruction onto the computation.
1868 XlaOp Ge(XlaOp lhs, XlaOp rhs,
1869          absl::Span<const int64> broadcast_dimensions = {});
1870 XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
1871                    absl::Span<const int64> broadcast_dimensions = {});
1872 
1873 // Enqueues a greater-than comparison instruction onto the computation.
1874 XlaOp Gt(XlaOp lhs, XlaOp rhs,
1875          absl::Span<const int64> broadcast_dimensions = {});
1876 XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
1877                    absl::Span<const int64> broadcast_dimensions = {});
1878 
1879 // Enqueues a less-than comparison instruction onto the computation.
1880 XlaOp Lt(XlaOp lhs, XlaOp rhs,
1881          absl::Span<const int64> broadcast_dimensions = {});
1882 XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
1883                    absl::Span<const int64> broadcast_dimensions = {});
1884 
1885 // Enqueues a less-or-equal comparison instruction onto the computation.
1886 XlaOp Le(XlaOp lhs, XlaOp rhs,
1887          absl::Span<const int64> broadcast_dimensions = {});
1888 XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
1889                    absl::Span<const int64> broadcast_dimensions = {});
1890 
1891 // Enqueues a comparison instruction onto the computation (optionally without
1892 // broadcast_dimensions for consistency with others).
1893 XlaOp Compare(XlaOp lhs, XlaOp rhs,
1894               absl::Span<const int64> broadcast_dimensions,
1895               ComparisonDirection direction, Comparison::Type compare_type);
1896 XlaOp Compare(XlaOp lhs, XlaOp rhs,
1897               absl::Span<const int64> broadcast_dimensions,
1898               ComparisonDirection direction);
1899 XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
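
// Example (illustrative sketch): comparing a matrix against a per-column
// vector, assuming hypothetical XlaOps `m` (f32[2,3]) and `v` (f32[3]).
//
//   // `v` is broadcast along dimension 1 of `m`.
//   XlaOp mask = Ge(m, v, /*broadcast_dimensions=*/{1});  // pred[2,3]
//   // The same comparison spelled with the generic form:
//   XlaOp mask2 = Compare(m, v, {1}, ComparisonDirection::kGe);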
1900 
1901 // Enqueues a dot instruction onto the computation.
1902 XlaOp Dot(XlaOp lhs, XlaOp rhs,
1903           const PrecisionConfig* precision_config = nullptr,
1904           absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1905 
1906 // Enqueues a general dot instruction onto the computation.
1907 XlaOp DotGeneral(
1908     XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
1909     const PrecisionConfig* precision_config = nullptr,
1910     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
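
// Example (illustrative sketch): a plain matmul via Dot, and the same
// contraction expressed with DotGeneral, assuming hypothetical XlaOps `lhs`
// (f32[2,3]) and `rhs` (f32[3,4]).
//
//   XlaOp prod = Dot(lhs, rhs);  // f32[2,4]
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);
//   dnums.add_rhs_contracting_dimensions(0);
//   XlaOp prod2 = DotGeneral(lhs, rhs, dnums);  // also f32[2,4]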
1911 
1912 // Enqueues a convolution instruction onto the computation, which uses the
1913 // default convolution dimension numbers.
1914 XlaOp Conv(
1915     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1916     Padding padding, int64_t feature_group_count = 1,
1917     int64_t batch_group_count = 1,
1918     const PrecisionConfig* precision_config = nullptr,
1919     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
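
// Example (illustrative sketch): a 2D convolution with the default dimension
// numbers, assuming hypothetical XlaOps `input` (f32[8,1,28,28]) and `kernel`
// (f32[16,1,3,3]).
//
//   XlaOp out = Conv(input, kernel, /*window_strides=*/{1, 1}, Padding::kSame);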
1920 
1921 // Enqueues a convolution instruction onto the computation, with the caller
1922 // provided padding configuration in the format returned by MakePadding().
1923 XlaOp ConvWithGeneralPadding(
1924     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1925     absl::Span<const std::pair<int64, int64>> padding,
1926     int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1927     const PrecisionConfig* precision_config = nullptr,
1928     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1929 
1930 // Enqueues a convolution instruction onto the computation, with the caller
1931 // provided dimension numbers configuration.
1932 XlaOp ConvWithGeneralDimensions(
1933     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1934     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1935     int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1936     const PrecisionConfig* precision_config = nullptr,
1937     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1938 
1939 // Enqueues a convolution instruction onto the computation, with the caller
1940 // provided padding configuration as well as the dimension numbers.
1941 XlaOp ConvGeneral(
1942     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1943     absl::Span<const std::pair<int64, int64>> padding,
1944     const ConvolutionDimensionNumbers& dimension_numbers,
1945     int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1946     const PrecisionConfig* precision_config = nullptr,
1947     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1948 
1949 // Enqueues a convolution instruction onto the computation, with the caller
1950 // provided padding configuration, dilation factors and dimension numbers.
1951 XlaOp ConvGeneralDilated(
1952     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1953     absl::Span<const std::pair<int64, int64>> padding,
1954     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1955     const ConvolutionDimensionNumbers& dimension_numbers,
1956     int64_t feature_group_count = 1, int64_t batch_group_count = 1,
1957     const PrecisionConfig* precision_config = nullptr,
1958     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1959 
1960 XlaOp DynamicConvForward(
1961     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1962     absl::Span<const std::pair<int64, int64>> padding,
1963     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1964     const ConvolutionDimensionNumbers& dimension_numbers,
1965     int64_t feature_group_count, int64_t batch_group_count,
1966     const PrecisionConfig* precision_config, PaddingType padding_type,
1967     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1968 
1969 XlaOp DynamicConvInputGrad(
1970     XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
1971     absl::Span<const int64> window_strides,
1972     absl::Span<const std::pair<int64, int64>> padding,
1973     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1974     const ConvolutionDimensionNumbers& dimension_numbers,
1975     int64_t feature_group_count, int64_t batch_group_count,
1976     const PrecisionConfig* precision_config, PaddingType padding_type,
1977     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1978 
1979 XlaOp DynamicConvKernelGrad(
1980     XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
1981     absl::Span<const std::pair<int64, int64>> padding,
1982     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
1983     const ConvolutionDimensionNumbers& dimension_numbers,
1984     int64_t feature_group_count, int64_t batch_group_count,
1985     const PrecisionConfig* precision_config, PaddingType padding_type,
1986     absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
1987 
1988 // Enqueues an FFT instruction onto the computation, of the given type and
1989 // with the given FFT length.
1990 XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);
1991 
1992 // Solves systems of linear equations with lower or upper triangular coefficient
1993 // matrices by forward- or back-substitution. Broadcasting along leading
1994 // dimensions, this routine solves for x in one of the matrix systems
1995 //   `op(a) * x = b`,  or `x * op(a) = b`,
1996 // for the variable `x` given `a` and `b`, where `op(a)` is either
1997 //   `op(a) = a`,  or `op(a) = transpose(a)`,  or `op(a) = conj(transpose(a))`.
1998 //
1999 // * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
2000 //   square matrices. If `lower` is true (false), then the strictly upper
2001 //   (lower) triangular part of each innermost matrix in `a` is assumed to be
2002 //   zero and is not accessed.
2003 // * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
2004 //   tensor of shape `[..., K, M]`.
2005 // * `left_side` is a boolean, indicating whether to solve a system of the form
2006 //   op(a) * x = b (true) or x * op(a) = b (false).
2007 // * `lower` is a boolean, indicating whether the argument `a` is
2008 //   lower-triangular (true) or upper-triangular (false).
2009 // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
2010 //   1 and not accessed.
2011 // * `transpose_a` indicates which function `op` we use to transform the tensor
2012 //   `a`: the identity function, transpose(a), or conjugate(transpose(a))
2013 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
2014                       bool unit_diagonal,
2015                       TriangularSolveOptions::Transpose transpose_a);
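
// Example (illustrative sketch): solving `a * x = b` for `x` with a
// lower-triangular `a`, assuming hypothetical XlaOps `a` (f32[4,4]) and `b`
// (f32[4,2]).
//
//   XlaOp x = TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);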
2016 
2017 // Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
2018 // positive definite matrices.
2019 // `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
2020 // two minor dimensions equal.
2021 // If `lower` is true, the data from the lower triangle is used; if false, the
2022 // upper triangle is used. The input data in the other triangle of the input
2023 // does not affect the output. Returns the output in the same lower/upper
2024 // triangle. The data returned in the other output triangle is arbitrary and
2025 // implementation-defined.
2026 //
2027 // If `a` is not Hermitian positive definite, returns an array full of NaNs.
2028 XlaOp Cholesky(XlaOp a, bool lower);
2029 
2030 // Enqueues an infeed instruction onto the computation, which writes data of
2031 // the given shape to the infeed buffer of the device.
2032 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
2033              const string& config = "");
2034 
2035 // Variant of Infeed which takes a token-shaped operand and produces a
2036 // two-element tuple containing the data value and a token-shaped value.
2037 // Tokens are used for ordering side-effecting operations.
2038 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
2039 XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
2040                       const string& config = "");
2041 
2042 // Enqueues an outfeed instruction onto the computation. This instruction
2043 // generates outgoing data transfers for the given data.
2044 //
2045 // shape_with_layout communicates the laid out shape that we want to outfeed
2046 // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
2047 // will occur.
2048 void Outfeed(XlaOp operand, const Shape& shape_with_layout,
2049              const string& outfeed_config);
2050 
2051 // Variant of Outfeed which takes a token-shaped operand and produces a
2052 // token-shaped value. Tokens are used for ordering side-effecting operations.
2053 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
2054 XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
2055                        const Shape& shape_with_layout,
2056                        const string& outfeed_config);
2057 
2058 // Enqueues a call instruction onto the computation.
2059 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
2060            absl::Span<const XlaOp> operands);
2061 
2062 // Enqueues a custom call instruction onto the computation. A custom call
2063 // invokes code external to XLA. The |operands| are passed to the external code,
2064 // and the external code is expected to produce a result of the given
2065 // |shape|. The exact mechanism is backend-specific. For example, in the CPU
2066 // backend, a call instruction is emitted which targets a symbol with the name
2067 // |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
2068 // but |call_target_name| should be short as it may be used in labels. |opaque|
2069 // can encode arbitrarily large amounts of information. |has_side_effect|
2070 // specifies whether the instruction can have side effects.
2071 // |output_operand_aliasing| specifies a list of output/operand buffer pairs
2072 // that alias each other, where the output buffer is represented as a
2073 // ShapeIndex, and the operand buffer is represented as the operand index and
2074 // the ShapeIndex.
2075 XlaOp CustomCall(
2076     XlaBuilder* builder, const string& call_target_name,
2077     absl::Span<const XlaOp> operands, const Shape& shape,
2078     const string& opaque = "", bool has_side_effect = false,
2079     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2080         output_operand_aliasing = {},
2081     const Literal* literal = nullptr,
2082     CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2083     CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
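
// Example (illustrative sketch): calling an external function registered
// under a hypothetical symbol name, assuming an XlaOp `x` of shape f32[4] on
// builder `b`. "my_custom_relu" is not a real target; registration of the
// symbol is backend-specific and omitted here.
//
//   XlaOp y = CustomCall(&b, "my_custom_relu", /*operands=*/{x},
//                        /*shape=*/ShapeUtil::MakeShape(F32, {4}));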
2084 
2085 // Overload which constructs a custom call that applies an Xla computation.
2086 XlaOp CustomCallWithComputation(
2087     XlaBuilder* builder, const string& call_target_name,
2088     absl::Span<const XlaOp> operands, const XlaComputation& computation,
2089     const Shape& shape, const string& opaque = "", bool has_side_effect = false,
2090     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2091         output_operand_aliasing = {},
2092     const Literal* literal = nullptr,
2093     CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2094     CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2095 
2096 // Overload which constructs a custom call with fixed layouts. The operands will
2097 // have the layouts specified by |operand_shapes_with_layout| when provided to
2098 // external code, and the external code is expected to produce a result with the
2099 // layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
2100 // and |operand_shapes_with_layout| must have layouts.
2101 XlaOp CustomCallWithLayout(
2102     XlaBuilder* builder, const string& call_target_name,
2103     absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
2104     absl::Span<const Shape> operand_shapes_with_layout,
2105     const string& opaque = "", bool has_side_effect = false,
2106     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2107         output_operand_aliasing = {},
2108     const Literal* literal = nullptr,
2109     CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2110     CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2111 
2112 // Overload which annotates a custom call with the given Window and
2113 // ConvolutionDimensionNumbers.  Useful for custom-calls which represent
2114 // convolutions.
2115 //
2116 // This sets the layout of its operands if operand_shapes_with_layout is
2117 // nonempty, and it sets the layout of its result if `shape` has a layout.
2118 XlaOp CustomCallWithConvDnums(
2119     XlaBuilder* builder, const string& call_target_name,
2120     absl::Span<const XlaOp> operands, const Shape& shape,
2121     absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
2122     bool has_side_effect,
2123     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
2124         output_operand_aliasing,
2125     const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
2126     CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE,
2127     CustomCallApiVersion api_version = API_VERSION_ORIGINAL);
2128 
2129 // The following methods enqueue element-wise binary arithmetic operations
2130 // onto the computation. The shapes of the operands have to match unless one
2131 // of the operands is a scalar, or an explicit broadcast dimension is given
2132 // (see g3doc for more details).
2133 
2134 // Enqueues a complex compose instruction onto the computation.
2135 XlaOp Complex(XlaOp real, XlaOp imag,
2136               absl::Span<const int64> broadcast_dimensions = {});
2137 
2138 // Enqueues a complex conjugate instruction onto the computation.
2139 XlaOp Conj(XlaOp operand);
2140 
2141 // Enqueues an add instruction onto the computation.
2142 XlaOp Add(XlaOp lhs, XlaOp rhs,
2143           absl::Span<const int64> broadcast_dimensions = {});
2144 
2145 // Enqueues a subtract instruction onto the computation.
2146 XlaOp Sub(XlaOp lhs, XlaOp rhs,
2147           absl::Span<const int64> broadcast_dimensions = {});
2148 
2149 // Enqueues a multiply instruction onto the computation.
2150 XlaOp Mul(XlaOp lhs, XlaOp rhs,
2151           absl::Span<const int64> broadcast_dimensions = {});
2152 
2153 // Enqueues a divide instruction onto the computation.
2154 XlaOp Div(XlaOp lhs, XlaOp rhs,
2155           absl::Span<const int64> broadcast_dimensions = {});
2156 
2157 // Enqueues a remainder instruction onto the computation.
2158 XlaOp Rem(XlaOp lhs, XlaOp rhs,
2159           absl::Span<const int64> broadcast_dimensions = {});
2160 
2161 // Enqueues a max instruction onto the computation.
2162 XlaOp Max(XlaOp lhs, XlaOp rhs,
2163           absl::Span<const int64> broadcast_dimensions = {});
2164 
2165 // Enqueues a min instruction onto the computation.
2166 XlaOp Min(XlaOp lhs, XlaOp rhs,
2167           absl::Span<const int64> broadcast_dimensions = {});
2168 
2169 // Element-wise logical operators
2170 XlaOp And(XlaOp lhs, XlaOp rhs,
2171           absl::Span<const int64> broadcast_dimensions = {});
2172 
2173 // Overload to call And with 3 or more operands.  We need the following somewhat
2174 // convoluted overload set to disambiguate with the overload that takes the
2175 // `broadcast_dimensions` optional param.
2176 inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
2177   return And(op1, And(op2, op3));
2178 }
2179 template <typename... XlaOpTs>
2180 XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
2181   return And(op1, And(op2, And(op3, operands...)));
2182 }
2183 
2184 XlaOp Or(XlaOp lhs, XlaOp rhs,
2185          absl::Span<const int64> broadcast_dimensions = {});
2186 
2187 // Overload to call Or with 3 or more operands.  As with `And`, we need the
2188 // following complicated overload set to handle the default arg in the `Or`
2189 // overload above.
Or(XlaOp op1,XlaOp op2,XlaOp op3)2190 inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
2191   return Or(op1, Or(op2, op3));
2192 }
2193 template <typename... XlaOpTs>
Or(XlaOp op1,XlaOp op2,XlaOp op3,const XlaOpTs &...operands)2194 XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
2195   return Or(op1, Or(op2, Or(op3, operands...)));
2196 }
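
// Usage sketch: the variadic overloads right-fold into nested binary ops.  For
// assumed PRED-shaped values a, b, c, d:
//
//   XlaOp all = And(a, b, c, d);  // == And(a, And(b, And(c, d)))
//   XlaOp any = Or(a, b, c, d);   // == Or(a, Or(b, Or(c, d)))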

XlaOp Xor(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Not(XlaOp operand);

XlaOp PopulationCount(XlaOp operand);

XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
                        absl::Span<const int64> broadcast_dimensions = {});

// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);
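
// Illustrative sketch: summing an assumed f32[4, 8] `input` over dimension 0.
// The scalar `add` reduction computation is built on a separate sub-builder;
// all names here are hypothetical.
//
//   XlaBuilder add_b("add");
//   {
//     auto x = Parameter(&add_b, 0, ShapeUtil::MakeShape(F32, {}), "x");
//     auto y = Parameter(&add_b, 1, ShapeUtil::MakeShape(F32, {}), "y");
//     Add(x, y);
//   }
//   XlaComputation add = add_b.Build().ValueOrDie();
//   // input: f32[4, 8]  ->  col_sums: f32[8]
//   XlaOp col_sums = Reduce(input, ConstantR0<float>(&builder, 0.0f), add,
//                           /*dimensions_to_reduce=*/{0});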

// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                const XlaComputation& computation);

// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);
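
// Illustrative sketch: a 2x2 sum-pool with stride 2 over an assumed NHWC
// input, reusing a scalar `add` computation like the one sketched for Reduce
// above.
//
//   // input: f32[1, 4, 4, 1]  ->  pooled: f32[1, 2, 2, 1]
//   XlaOp pooled = ReduceWindow(input, ConstantR0<float>(&builder, 0.0f), add,
//                               /*window_dimensions=*/{1, 2, 2, 1},
//                               /*window_strides=*/{1, 2, 2, 1},
//                               Padding::kValid);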

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    XlaOp operand, XlaOp init_value, const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);
XlaOp ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});

XlaOp AllGather(
    XlaOp operand, int64_t all_gather_dimension, int64_t shard_count,
    absl::Span<const ReplicaGroup> replica_groups = {},
    const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
    const absl::optional<Layout>& layout = absl::nullopt,
    const absl::optional<bool> use_global_device_ids = absl::nullopt);

// Enqueues an operation that does an AllReduce of the operand across cores.
// Here AllReduce means doing a reduction on the input operand across cores and
// then broadcasting the reduction result to those cores. The reduction
// function is defined by `computation`, which should be a commutative
// computation on scalars, e.g., add, min, or max. The way that AllReduce is
// applied is configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
// empty, all replicas belong to one group. AllReduce will be applied within
// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
// replicas 0 and 2 are in subgroup 0 and replicas 1 and 3 are in subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied across modules.
//
// - `shape_with_layout`: forces the layout of the AllReduce to the given
// layout. This is used to guarantee the same layout for a group of AllReduce
// ops compiled separately.
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups = {},
                const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
                const absl::optional<Shape>& shape_with_layout = absl::nullopt);
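
// Illustrative sketch: an all-reduce sum within two subgroups of an assumed
// four-replica configuration; `add` is a scalar addition computation as in the
// Reduce sketch above.
//
//   ReplicaGroup g0, g1;
//   g0.add_replica_ids(0);
//   g0.add_replica_ids(2);
//   g1.add_replica_ids(1);
//   g1.add_replica_ids(3);
//   XlaOp summed = AllReduce(x, add, /*replica_groups=*/{g0, g1});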

XlaOp ReduceScatter(
    XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
    int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups = {},
    const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
    const absl::optional<Layout>& layout = absl::nullopt,
    const absl::optional<bool> use_global_device_ids = absl::nullopt);

// Enqueues an operation that does an AllToAll of the operand across cores.
// An optional `layout` can be specified to force the layout of the instruction.
// This is used to guarantee the same layout for a group of AllToAll ops
// compiled separately.
XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension,
               int64_t split_count,
               absl::Span<const ReplicaGroup> replica_groups = {},
               const absl::optional<Layout>& layout = absl::nullopt);

XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension,
                    int64_t concat_dimension, int64_t split_count,
                    absl::Span<const ReplicaGroup> replica_groups = {},
                    const absl::optional<Layout>& layout = absl::nullopt);

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that 1) no two pairs may have the same target replica
// id, and no two pairs may have the same source replica id; 2) if a replica id
// is not a target in any pair, then the output on that replica is a tensor
// consisting of 0(s) with the same shape as the input.
XlaOp CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs);
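
// Illustrative sketch: rotate a value one replica to the right among an
// assumed group of three replicas.
//
//   XlaOp rotated = CollectivePermute(x, {{0, 1}, {1, 2}, {2, 0}});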

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       XlaOp source, XlaOp init_value,
                       const XlaComputation& scatter);

// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(XlaOp operand);

// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(XlaOp y, XlaOp x,
            absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(XlaOp operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(XlaOp operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(XlaOp operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(XlaOp operand);

// Enqueues a round instruction onto the computation, rounding to the nearest
// integer, with half-way cases rounding away from zero.
XlaOp Round(XlaOp operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(XlaOp operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(XlaOp operand);

// Enqueues a logistic instruction onto the computation.
XlaOp Logistic(XlaOp operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(XlaOp operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(XlaOp operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(XlaOp operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(XlaOp operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(XlaOp operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(XlaOp operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(XlaOp operand);

// Enqueues a sqrt instruction onto the computation.
XlaOp Sqrt(XlaOp operand);

// Enqueues a cbrt instruction onto the computation.
XlaOp Cbrt(XlaOp operand);

// Enqueues a rsqrt instruction onto the computation.
XlaOp Rsqrt(XlaOp operand);

// Enqueues a power instruction (lhs^rhs) onto the computation.
XlaOp Pow(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not +/-Inf or NaN.  Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(XlaOp operand);

// Enqueues an iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size);
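
// Usage sketch for both Iota forms (names assumed):
//
//   // Each row is [0, 1, ..., 7]; result shape s32[4, 8].
//   XlaOp rows = Iota(&builder, ShapeUtil::MakeShape(S32, {4, 8}),
//                     /*iota_dimension=*/1);
//   // [0, 1, 2, 3] as s32[4].
//   XlaOp vec = Iota(&builder, S32, 4);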

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
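
// Illustrative sketch: reinterpreting float bits as an unsigned integer of the
// same width, e.g. to extract the sign bit.  `f32_op` is an assumed F32 value.
//
//   XlaOp bits = BitcastConvertType(f32_op, U32);
//   XlaOp sign = ShiftRightLogical(bits, ConstantR0<uint32>(&builder, 31));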

// Enqueues a negate instruction onto the computation.
XlaOp Neg(XlaOp operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided, then the
//   last dimension is chosen by default. For the dimension which is sorted, the
//   same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types of
//   the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameters 2 * i and 2 * i + 1 correspond to the value of operand i
//   at the two index positions.
// Default comparator computations can be found in lib/comparators.h.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension = -1, bool is_stable = false);
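
// Illustrative sketch: sorting a batch of keys row-by-row and applying the
// same permutation to a parallel values tensor.  This assumes
// CreateScalarLtComputation from lib/comparators.h (mentioned above) for the
// "<" comparator; all operand names are hypothetical.
//
//   XlaComputation less = CreateScalarLtComputation({F32, S32}, &builder);
//   // keys: f32[16, 8], values: s32[16, 8]; sorted along dimension 1.
//   XlaOp sorted = Sort({keys, values}, less, /*dimension=*/1,
//                       /*is_stable=*/true);
//   XlaOp sorted_keys = GetTupleElement(sorted, 0);
//   XlaOp sorted_vals = GetTupleElement(sorted, 1);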

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues an N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
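
// Usage sketch: 16 uniform samples in [0, 1) (builder name assumed):
//
//   XlaOp u = RngUniform(ConstantR0<float>(&builder, 0.0f),
//                        ConstantR0<float>(&builder, 1.0f),
//                        ShapeUtil::MakeShape(F32, {16}));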

// Enqueues a B(initial_state) random bit generation instruction onto the
// computation. Returns the new key and random bits with the specified shape.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                      const Shape& shape);

// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            XlaOp init);
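
// Illustrative sketch: a scalar loop that counts from 0 to 10.  The condition
// and body are built on assumed sub-builders, each taking the loop state as
// parameter 0.
//
//   XlaBuilder cond_b("cond"), body_b("body");
//   {
//     auto s = Parameter(&cond_b, 0, ShapeUtil::MakeShape(S32, {}), "s");
//     Lt(s, ConstantR0<int32>(&cond_b, 10));
//   }
//   {
//     auto s = Parameter(&body_b, 0, ShapeUtil::MakeShape(S32, {}), "s");
//     Add(s, ConstantR0<int32>(&body_b, 1));
//   }
//   XlaOp ten = While(cond_b.Build().ValueOrDie(),
//                     body_b.Build().ValueOrDie(),
//                     ConstantR0<int32>(&builder, 0));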

// Enqueues a conditional node onto the computation.
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                  const XlaComputation& true_computation, XlaOp false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index selects the (N-1)'th
// branch_computation as the default.
XlaOp Conditional(XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);
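
// Illustrative sketch: a two-way switch.  `relu` and `identity` are assumed
// single-parameter XlaComputations; an out-of-range `index` falls through to
// `identity`, the last branch.
//
//   XlaOp out = Conditional(index, {&relu, &identity}, {x, x});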

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(XlaOp input, XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes,
             bool indices_are_sorted = false);
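
// Illustrative sketch: gathering rows of a matrix by index.  The dimension
// numbers below are one standard row-gather configuration; `params` is an
// assumed f32[N, 3] operand and `indices` an s32[K] vector of row ids.
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 holds each row slice
//   dnums.add_collapsed_slice_dims(0);  // the size-1 row dim is collapsed
//   dnums.add_start_index_map(0);       // indices address operand dim 0
//   dnums.set_index_vector_dim(1);      // scalar indices (trailing dim)
//   // result: f32[K, 3]
//   XlaOp rows = Gather(params, indices, dnums, /*slice_sizes=*/{1, 3});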

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value.  Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value.  Tokens are used for ordering side-effecting operations.
// This is a separate function from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64_t feature_index);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// instead uses the input `mean` and `variance` as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64_t feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_offset: Gradient with respect to input `offset`
//   - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64_t feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64_t dimension);

XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension);

// Returns the same op but with the dynamic dimension removed.
XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension);

// Implementation details below this point.
//
// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  BorrowingLiteral literal(
      reinterpret_cast<const char*>(values.begin()),
      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
                           {static_cast<int64>(values.size())}));
  return ConstantLiteral(builder, literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}
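
// Usage sketch for the constant helpers above (builder name assumed):
//
//   XlaOp pi    = ConstantR0<float>(&builder, 3.14159f);
//   XlaOp v     = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
//   XlaOp zeros = ConstantR1<float>(&builder, /*length=*/8, 0.0f);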

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_