/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

45 namespace xla {
46 
47 class XlaBuilder;
48 
49 // This represents an instruction that has been enqueued using the XlaBuilder.
50 // This is used to pass to subsequent computations that depends upon the
51 // instruction as an operand.
52 class XlaOp {
53  public:
XlaOp()54   XlaOp() : handle_(-1), builder_(nullptr) {
55     static_assert(std::is_trivially_destructible<XlaOp>::value,
56                   "XlaOp should be trivially destructible");
57   }
58   ~XlaOp() = default;
59 
60   XlaOp(const XlaOp& other) = default;
61   XlaOp& operator=(const XlaOp& other) = default;
62 
63   // Precondition: !IsUninitialized().
64   //
65   // It's very common to do foo.builder()->bar().  Without this precondition, if
66   // foo.builder() is null, the call to bar will segfault at some point possibly
67   // deep in the callstack when we finally dereference `this`.  The precondition
68   // lets us avoid this tricky-to-debug problem.
builder()69   XlaBuilder* builder() const {
70     CHECK(builder_ != nullptr);
71     return builder_;
72   }
73 
74   // Returns true if the XlaOp represents valid, non-erroneous value.
valid()75   bool valid() const { return handle_ >= 0; }
76 
77   // Returns true if the XlaOp was created by the XlaOp() constructor and
78   // not returned by a builder.
IsUninitialized()79   bool IsUninitialized() const { return builder_ == nullptr; }
80 
IsIdenticalTo(XlaOp rhs)81   bool IsIdenticalTo(XlaOp rhs) const {
82     return handle_ == rhs.handle_ && builder_ == rhs.builder_;
83   }
84 
85   friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
86     out << op.handle();
87     return out;
88   }
89 
90  private:
XlaOp(XlaBuilder * builder)91   explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
XlaOp(int64 handle,XlaBuilder * builder)92   XlaOp(int64 handle, XlaBuilder* builder)
93       : handle_(handle), builder_(builder) {}
94 
handle()95   int64 handle() const { return handle_; }
96 
97   friend class XlaBuilder;
98 
99   // < 0 means "invalid handle".
100   int64 handle_;
101 
102   // Not owned. Non-null for any handle returned by XlaBuilder, even if the
103   // handle is invalid.
104   XlaBuilder* builder_;
105 };
106 
107 // Arithmetic operator overloads for the XlaOp type.
108 XlaOp operator-(XlaOp x);
109 XlaOp operator+(XlaOp x, XlaOp y);
110 XlaOp operator-(XlaOp x, XlaOp y);
111 XlaOp operator*(XlaOp x, XlaOp y);
112 XlaOp operator/(XlaOp x, XlaOp y);
113 XlaOp operator%(XlaOp x, XlaOp y);
114 
115 // Bitwise operator overloads for the XlaOp type.
116 XlaOp operator~(XlaOp x);
117 XlaOp operator&(XlaOp x, XlaOp y);
118 XlaOp operator|(XlaOp x, XlaOp y);
119 XlaOp operator^(XlaOp x, XlaOp y);
120 XlaOp operator<<(XlaOp x, XlaOp y);
121 // Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
122 // a right logical shift.
123 XlaOp operator>>(XlaOp x, XlaOp y);
124 
125 // We don't overload the relational operators (==, !=, <, <=, >, >=) because the
126 // semantics might be surprising since their result types are usually 'bool'.
127 // Further programmers may expect == to be a structural equality.
128 // We also choose not to overload any of the mutating operators (e.g., +=, -=)
129 // because the semantics might be misleading — XLA computations are immutable.
130 
131 // A convenient interface for building up computations.
132 //
133 // Thread-compatible.
134 class XlaBuilder {
135  public:
136   // computation_name: name to use for the built computation.
137   XlaBuilder(const string& computation_name);
138 
139   XlaBuilder(const XlaBuilder&) = delete;
140   XlaBuilder& operator=(const XlaBuilder&) = delete;
141 
142   ~XlaBuilder();
143 
144   // Returns the computation name.
name()145   const string& name() const { return name_; }
146 
147   // Sets OpMetadata that will be added to all instructions until cleared.
148   //
149   // OpMetadata is often applied to a series of XLA HLO instructions. As a
150   // result, OpMetadata is set on the computation builder. All subsequent
151   // instructions generated via this computation builder will have the same
152   // OpMetadata attached until a call to ClearOpMetadata.
SetOpMetadata(OpMetadata metadata)153   void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
154 
155   // Clears the HloMetadata state.
ClearOpMetadata()156   void ClearOpMetadata() { metadata_.Clear(); }
157 
158   // Sets an OpSharding that will be attached to all instructions until cleared.
SetSharding(const OpSharding & sharding)159   void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
160 
161   // Sets the FrontendAttributes that will be added to all instructions until
162   // cleared.
163   //
164   // FrontendAttributes are often applied to a series of XLA HLO instructions.
165   // As a result they are set on the computation builder and all the
166   // instructions generated via the computation builder will have the same
167   // frontend attributes attached to them.
SetFrontendAttributes(const FrontendAttributes & frontend_attributes)168   void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
169     frontend_attributes_ = frontend_attributes;
170   }
171 
172   // Swap the passed FrontendAttributes with the ones currently set.
173   //
174   // Return the old attributes.
SwapFrontendAttributes(const FrontendAttributes & frontend_attributes)175   FrontendAttributes SwapFrontendAttributes(
176       const FrontendAttributes& frontend_attributes) {
177     FrontendAttributes old_attributes = std::move(frontend_attributes_);
178     frontend_attributes_ = frontend_attributes;
179     return old_attributes;
180   }
181 
182   // Returns the FrontendAttributes that will be attached to all instructions.
frontend_attributes()183   const FrontendAttributes& frontend_attributes() const {
184     return frontend_attributes_;
185   }
186 
187   // Clears all the frontend attributes.
ClearFrontendAttributes()188   void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
189 
190   // Clears the sharding. Ops will be sharded according to the default placement
191   // policy.
ClearSharding()192   void ClearSharding() { sharding_ = absl::nullopt; }
193 
194   // Returns the OpSharding that will be attached to all instructions.
sharding()195   const absl::optional<OpSharding>& sharding() const { return sharding_; }
196 
197   // Sets the builder to a mode where it will die immediately when an error is
198   // encountered, rather than producing it in a deferred fashion when Build() is
199   // called (which is the default).
set_die_immediately_on_error(bool enabled)200   void set_die_immediately_on_error(bool enabled) {
201     die_immediately_on_error_ = enabled;
202   }
203 
204   // Default dimension numbers used for a 2D convolution.
205   static constexpr int64 kConvBatchDimension = 0;
206   static constexpr int64 kConvFeatureDimension = 1;
207   static constexpr int64 kConvFirstSpatialDimension = 2;
208   static constexpr int64 kConvSecondSpatialDimension = 3;
209   static constexpr int64 kConvKernelOutputDimension = 0;
210   static constexpr int64 kConvKernelInputDimension = 1;
211   static constexpr int64 kConvKernelFirstSpatialDimension = 2;
212   static constexpr int64 kConvKernelSecondSpatialDimension = 3;
213 
214   // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
215   // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
216   // the kernel operand
217   // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
218   static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
219       int num_spatial_dims = 2);
220 
221   // Returns an error if the convolution dimension numbers have conflicts.
222   static Status Validate(const ConvolutionDimensionNumbers& dnum);
223 
224   // Returns a new XlaBuilder whose resultant Computation is used only by this
225   // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
226   // behavior as the parent.
227   std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
228 
229   // Builds the computation with the requested operations, or returns a non-ok
230   // status. Note that all ops that have been enqueued will be moved to the
231   // computation being returned. The root of the computation will be the last
232   // added operation.
233   //
234   // `remove_dynamic_dimensions` tells the builder whether to remove the
235   // dynamic dimensions information in all ops.
236   //
237   // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the
238   // dynamic dimensions information when XLA backend can handle dynamic
239   // dimensions.
240   StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);
241 
242   // Overload of Build which specifies a particular root instruction for the
243   // computation.
244   StatusOr<XlaComputation> Build(XlaOp root,
245                                  bool remove_dynamic_dimensions = false);
246 
247   // Builds the computation with the requested operations, or notes an error in
248   // the parent XlaBuilder and returns an empty computation if building failed.
249   // This function is intended to be used where the returned XlaComputation is
250   // only used by the parent XlaBuilder and hence further operation on the
251   // returned XlaComputation will simply be error'ed out if an error occurred
252   // while building this computation. If the built computation is to be used by
253   // a XlaBuilder other than the parent XlaBuilder then Build() should be used
254   // instead.
255   XlaComputation BuildAndNoteError();
256 
257   // Returns a subgraph that roots on the given root. If the root is not a
258   // compile-time constant (see `IsConstant`), returns an error.
259   //
260   // This will copy the needed ops/computations to the subgraph.
261   StatusOr<XlaComputation> BuildConstantSubGraph(
262       XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
263 
264   // Returns the first error that was encountered while building the
265   // computation. When an error is encountered, by default we return a vacuous
266   // XlaOp and inform the user of the error that occurred while
267   // building the computation when they make a final call to Build().
268   //
269   // See also set_die_immediately_on_error().
first_error()270   Status first_error() const { return first_error_; }
271 
272   // Returns the current status of the builder, complete with the stack trace
273   // information.
274   Status GetCurrentStatus() const;
275 
276   // Returns the shape of the given op.
277   StatusOr<Shape> GetShape(XlaOp op) const;
278 
279   // Returns the shape of the given op.
280   StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
281 
282   // Returns the (inferred) result for the current computation's shape. This
283   // assumes the root instruction is the last added instruction.
284   StatusOr<ProgramShape> GetProgramShape() const;
285 
286   // Returns the (inferred) result for the current computation's shape using the
287   // given operation as the root.
288   StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;
289 
290   // Reports an error to the builder, by
291   // * storing it internally and capturing a backtrace if it's the first error
292   //   (this deferred value will be produced on the call to
293   //    Build()/GetShape()/...)
294   // * dying if die_immediately_on_error_ is true.
295   // Returns an XlaOp with an invalid handle but a valid builder. This value can
296   // be returned in place of a value in APIs that return an XlaOp.
297   XlaOp ReportError(const Status& error);
298 
299   // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
300   // If the Status was an error, reports the error to builder and returns an
301   // invalid XlaOp handle.
302   XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
303 
304   // A helper function that runs a function that returns a StatusOr<XlaOp> and
305   // returns an XlaOp.
306   XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
307 
308   // Returns true if 'operand' is a compile-time constant. A compile-time
309   // constant does not depend on any parameters, or on stateful operators such
310   // as `RngNormal` or `Infeed`.
311   //
312   // This tests whether a computation is a compile-time constant without
313   // evaluating the computation.
314   StatusOr<bool> IsConstant(XlaOp operand) const;
315 
316   // Sets up binding which indicates that the `target_dim_num` in the subshape
317   // `target_param_index` of parameter `target_param_num` is a dynamic dimension
318   // and its real dynamic size is represented by `dynamic_param_index` in
319   // parameter `dynamic_param_num`.
320   //
321   // Note that this should be called before the dynamic parameters are used to
322   // create other operations, otherwise created operations won't have the
323   // dynamic dimensions information.
324   //
325   // TODO(b/119520625): Remove this API once we have more dynamic shape infra
326   // ready.
327   Status SetDynamicBinding(int64 dynamic_size_param_num,
328                            ShapeIndex dynamic_size_param_index,
329                            int64 target_param_num,
330                            ShapeIndex target_param_index, int64 target_dim_num);
331 
332   // Adds a new input/output alias. Since the input/output shape information are
333   // not available until the computation is built, and eventual error in the
334   // arguments of this API will be detected only at computation Build() time.
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index)335   void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
336                   const ShapeIndex& param_index) {
337     input_output_aliases_.push_back({output_index, param_number, param_index});
338   }
339 
340   // Describes an input/output alias as inserted by the SetUpAlias() API.
341   struct InputOutputAlias {
342     // Specifies the index of the aliased buffer in the result tuple.
343     ShapeIndex output_index;
344     // Specifies the parameter containing the buffer to be aliased.
345     int64 param_number;
346     // Specifies the index of the aliased buffer in the parameter
347     ShapeIndex param_index;
348   };
349 
350   // Looks up the HloInstruction and sets the frontend attribute "attribute" to
351   // "value".
352   //
353   // If the attribute already existed then its value is updated.
354   //
355   // Note: the attribute is only added to the HloInstruction, not to the
356   // builder.
357   Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
358                                          string value);
359 
360  private:
361   // Build helper which takes the id of the root operation..
362   StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
363 
364   // Description for the methods below can be found in the corresponding public
365   // functions section in this file.
366 
367   XlaOp Parameter(int64 parameter_number, const Shape& shape,
368                   const string& name,
369                   const std::vector<bool>& replicated_at_leaf_buffers);
Parameter(int64 parameter_number,const Shape & shape,const string & name)370   XlaOp Parameter(int64 parameter_number, const Shape& shape,
371                   const string& name) {
372     std::vector<bool> empty_bools;
373     return Parameter(parameter_number, shape, name, empty_bools);
374   }
375 
376   XlaOp ConstantLiteral(const LiteralSlice& literal);
377 
378   XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);
379 
380   XlaOp BroadcastInDim(XlaOp operand,
381                        const absl::Span<const int64> out_dim_size,
382                        const absl::Span<const int64> broadcast_dimensions);
383 
384   XlaOp Pad(XlaOp operand, XlaOp padding_value,
385             const PaddingConfig& padding_config);
386 
387   XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
388                 absl::Span<const int64> new_sizes,
389                 int64 inferred_dimension = -1);
390 
391   XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
392                 int64 inferred_dimension = -1);
393 
394   XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
395 
396   XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
397               absl::Span<const int64> limit_indices,
398               absl::Span<const int64> strides);
399 
400   XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
401                    int64 stride, int64 dimno);
402 
403   ABSL_DEPRECATED("Use span-of-indices form instead")
404   XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
405                      absl::Span<const int64> slice_sizes);
406   XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
407                      absl::Span<const int64> slice_sizes);
408 
409   ABSL_DEPRECATED("Use span-of-indices form instead")
410   XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices);
411   XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
412                            absl::Span<const XlaOp> start_indices);
413 
414   XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
415 
416   void Trace(const string& tag, XlaOp operand);
417 
418   XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
419 
420   XlaOp Tuple(absl::Span<const XlaOp> elements);
421 
422   XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
423 
424   XlaOp Dot(XlaOp lhs, XlaOp rhs,
425             const PrecisionConfig* precision_config = nullptr);
426 
427   XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
428                    const DotDimensionNumbers& dimension_numbers,
429                    const PrecisionConfig* precision_config = nullptr);
430 
431   XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
432              Padding padding, int64 feature_group_count = 1,
433              int64 batch_group_count = 1,
434              const PrecisionConfig* precision_config = nullptr);
435 
436   XlaOp ConvWithGeneralPadding(
437       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
438       absl::Span<const std::pair<int64, int64>> padding,
439       int64 feature_group_count = 1, int64 batch_group_count = 1,
440       const PrecisionConfig* precision_config = nullptr);
441 
442   XlaOp ConvWithGeneralDimensions(
443       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
444       Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
445       int64 feature_group_count = 1, int64 batch_group_count = 1,
446       const PrecisionConfig* precision_config = nullptr);
447 
448   XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
449                     absl::Span<const int64> window_strides,
450                     absl::Span<const std::pair<int64, int64>> padding,
451                     const ConvolutionDimensionNumbers& dimension_numbers,
452                     int64 feature_group_count = 1, int64 batch_group_count = 1,
453                     const PrecisionConfig* precision_config = nullptr);
454 
455   XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
456                            absl::Span<const int64> window_strides,
457                            absl::Span<const std::pair<int64, int64>> padding,
458                            absl::Span<const int64> lhs_dilation,
459                            absl::Span<const int64> rhs_dilation,
460                            const ConvolutionDimensionNumbers& dimension_numbers,
461                            int64 feature_group_count = 1,
462                            int64 batch_group_count = 1,
463                            const PrecisionConfig* precision_config = nullptr);
464 
465   XlaOp Fft(XlaOp operand, FftType fft_type,
466             absl::Span<const int64> fft_length);
467 
468   XlaOp Infeed(const Shape& shape, const string& config = "");
469   XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
470                         const string& config = "");
471 
472   void Outfeed(XlaOp operand, const Shape& shape_with_layout,
473                const string& outfeed_config);
474   XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
475                          const Shape& shape_with_layout,
476                          const string& outfeed_config);
477 
478   XlaOp Call(const XlaComputation& computation,
479              absl::Span<const XlaOp> operands);
480 
481   XlaOp CustomCall(
482       const string& call_target_name, absl::Span<const XlaOp> operands,
483       const Shape& shape_with_layout, const string& opaque,
484       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
485 
486   XlaOp Reduce(XlaOp operand, XlaOp init_value,
487                const XlaComputation& computation,
488                absl::Span<const int64> dimensions_to_reduce);
489 
490   XlaOp Reduce(absl::Span<const XlaOp> operands,
491                absl::Span<const XlaOp> init_values,
492                const XlaComputation& computation,
493                absl::Span<const int64> dimensions_to_reduce);
494 
495   XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
496                   const XlaComputation& computation);
497 
498   XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
499                      const XlaComputation& computation,
500                      absl::Span<const int64> window_dimensions,
501                      absl::Span<const int64> window_strides, Padding padding);
502 
503   XlaOp ReduceWindowWithGeneralPadding(
504       XlaOp operand, XlaOp init_value, const XlaComputation& computation,
505       absl::Span<const int64> window_dimensions,
506       absl::Span<const int64> window_strides,
507       absl::Span<const int64> base_dilations,
508       absl::Span<const int64> window_dilations,
509       absl::Span<const std::pair<int64, int64>> padding);
510 
511   XlaOp CrossReplicaSum(XlaOp operand,
512                         absl::Span<const ReplicaGroup> replica_groups = {});
513 
514   XlaOp AllReduce(
515       XlaOp operand, const XlaComputation& computation,
516       absl::Span<const ReplicaGroup> replica_groups = {},
517       const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
518       const absl::optional<Shape>& shape_with_layout = absl::nullopt);
519 
520   XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
521                  int64 split_count,
522                  const std::vector<ReplicaGroup>& replica_groups);
523 
524   XlaOp CollectivePermute(
525       XlaOp operand,
526       const std::vector<std::pair<int64, int64>>& source_target_pairs);
527 
528   XlaOp ReplicaId();
529 
530   XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
531                          absl::Span<const int64> window_dimensions,
532                          absl::Span<const int64> window_strides,
533                          Padding padding, XlaOp source, XlaOp init_value,
534                          const XlaComputation& scatter);
535 
536   XlaOp SelectAndScatterWithGeneralPadding(
537       XlaOp operand, const XlaComputation& select,
538       absl::Span<const int64> window_dimensions,
539       absl::Span<const int64> window_strides,
540       absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
541       XlaOp init_value, const XlaComputation& scatter);
542 
543   XlaOp Iota(const Shape& shape, int64 iota_dimension);
544 
545   XlaOp Iota(PrimitiveType type, int64 size);
546 
547   XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);
548 
549   XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
550 
551   XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
552 
553   XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
554 
555   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
556              int64 dimension = -1, bool is_stable = false);
557 
558   XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
559 
560   XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
561             absl::Span<const int64> dimensions,
562             absl::Span<const XlaOp> static_operands = {});
563 
564   XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
565 
566   XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
567 
568   XlaOp While(const XlaComputation& condition, const XlaComputation& body,
569               XlaOp init);
570 
571   XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
572                     const XlaComputation& true_computation, XlaOp false_operand,
573                     const XlaComputation& false_computation);
574 
575   XlaOp Conditional(XlaOp branch_index,
576                     absl::Span<const XlaComputation* const> branch_computations,
577                     absl::Span<const XlaOp> branch_operands);
578 
579   XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
580                         const int mantissa_bits);
581 
582   XlaOp Gather(XlaOp input, XlaOp start_indices,
583                const GatherDimensionNumbers& dimension_numbers,
584                absl::Span<const int64> slice_sizes,
585                bool indices_are_sorted = false);
586 
587   XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
588                 const XlaComputation& update_computation,
589                 const ScatterDimensionNumbers& dimension_numbers,
590                 bool indices_are_sorted = false, bool unique_indices = false);
591 
592   void Send(XlaOp operand, const ChannelHandle& handle);
593   XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);
594 
595   XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
596                    const ChannelHandle& handle);
597 
598   XlaOp RecvFromHost(XlaOp token, const Shape& shape,
599                      const ChannelHandle& handle);
600 
601   XlaOp CreateToken();
602 
603   XlaOp AfterAll(absl::Span<const XlaOp> tokens);
604 
605   XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
606   XlaOp RecvWithToken(XlaOp token, const Shape& shape,
607                       const ChannelHandle& handle);
608 
609   XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
610                           float epsilon, int64 feature_index);
611 
612   XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
613                            XlaOp variance, float epsilon, int64 feature_index);
614 
615   XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
616                       XlaOp batch_var, XlaOp grad_output, float epsilon,
617                       int64 feature_index);
618 
619   XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
620 
621   XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
622 
623   StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
624                                  absl::Span<const XlaOp> operands = {});
625 
626   void AddCalledComputation(const XlaComputation& computation,
627                             HloInstructionProto* instr);
628 
629   StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
630   StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
631       int64 handle) const;
632   StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
633   StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
634 
635   // Internal helper method that does the building for an arbitrary unary op.
636   XlaOp UnaryOp(HloOpcode unop, XlaOp operand);
637 
638   // Internal helper method that does the building for an arbitrary binary op.
639   // broadcast_dimensions specifies which dimensions to use for broadcasting
640   // when the operation is between tensors of different ranks. The direction is
641   // only used if opcode is kCompare.
642   XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
643                  absl::Span<const int64> broadcast_dimensions,
644                  absl::optional<ComparisonDirection> direction = absl::nullopt);
645 
646   // Internal helper method that does the building for an arbitrary ternary op.
647   XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);
648 
649   XlaOp RngOp(RandomDistribution distribution,
650               absl::Span<const XlaOp> parameters, const Shape& shape);
651 
652   StatusOr<XlaOp> InDimBroadcast(const Shape& shape, XlaOp operand,
653                                  absl::Span<const int64> broadcast_dimensions);
654 
655   // Internal helper method that creates a sequence of instructions that
656   // performs an explicit broadcast of the operand to the target shape.
657   StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
658                                        XlaOp operand);
659 
660   // Internal helper method for creating a Reshape op with the already inferred
661   // shape.
662   StatusOr<XlaOp> Reshape(const Shape& shape, XlaOp operand,
663                           int64 inferred_dimension = -1);
664 
665   // Returns the (inferred) result for the program shape using the given root.
666   StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
667 
668   // Returns shapes for the operands.
669   StatusOr<std::vector<Shape>> GetOperandShapes(
670       absl::Span<const XlaOp> operands) const;
671 
672   // A visitor which checks whether an operation is a compile-time constant,
673   // meaning that it doesn't depend on any parameters, or on any stateful
674   // operation such as `RngNormal` or `Infeed`. The visitor walks the
675   // computation starting at a given operation and sets is_constant to false iff
676   // a parameter or stateful operation is encountered.
677   void IsConstantVisitor(const int64 op_handle,
678                          absl::flat_hash_set<int64>* visited,
679                          bool* is_constant) const;
680 
681   // Checks bounds for convolution parameters.
682   Status VerifyConvolution(
683       const Shape& lhs_shape, const Shape& rhs_shape,
684       const ConvolutionDimensionNumbers& dimension_numbers) const;
685 
GetNextId()686   int64 GetNextId() { return ++next_id_; }
687 
688   // Populates the module with the input/output alias information stored within
689   // the input_output_aliases vector.
690   static Status PopulateInputOutputAlias(
691       HloModuleProto* module, const ProgramShape& program_shape,
692       const std::vector<InputOutputAlias>& input_output_aliases);
693 
694   string name_;  // Name to use for the built computation.
695 
696   // The next sequential ID for every instruction/computation contained within
697   // this computation.
698   int64 next_id_ = 0;
699 
700   // The first error encountered while building the computation.
701   // This is OK until the first error is encountered.
702   Status first_error_;
703 
704   // The saved stack trace from the point at which the first error occurred.
705   tensorflow::SavedStackTrace first_error_backtrace_;
706 
707   // The instructions of this computation.
708   std::vector<HloInstructionProto> instructions_;
709 
  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
712   std::vector<std::unique_ptr<Shape>> instruction_shapes_;
713 
714   // Dynamic parameter configuration of this computation.
715   DynamicParameterBinding dynamic_parameter_binding_;
716 
717   // Holds the input/output alias information populated by the SetUpAlias() API.
718   std::vector<InputOutputAlias> input_output_aliases_;
719 
720   // A map from XlaOp::Handle to the index in the instructions_ vector where the
721   // instruction is held.
722   absl::flat_hash_map<int64, int64> handle_to_index_;
723 
724   // The embedded computations used by this computation. Each computation was
725   // the entry computation of some XlaComputation, the key is the unique id of
726   // that XlaComputation.
727   std::map<int64, HloComputationProto> embedded_;
728 
729   // The unique parameter numbers.
730   absl::flat_hash_set<int64> parameter_numbers_;
731 
732   // The metadata to attach to each op. This is structured as a "modal"-like
733   // operation, in order to simplify client code (and not sprinkle this metadata
734   // throughout the TensorFlow op kernel implementations).
735   OpMetadata metadata_;
736 
  // Sharding for this operator. This is structured as a "modal"-like
  // operation, in order to simplify client code, similar to metadata_.
739   absl::optional<OpSharding> sharding_;
740 
  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  // The builder that this builder was created from, if any; nullptr for a
  // top-level builder. Not owned. NOTE(review): presumably set when creating
  // a sub-builder — confirm against the constructor in the .cc file.
  XlaBuilder* parent_builder_{nullptr};

  // The frontend attributes currently in effect for enqueued ops; managed
  // modally via SetFrontendAttributes()/SwapFrontendAttributes() (see
  // XlaScopedFrontendAttributesAssignment).
  FrontendAttributes frontend_attributes_;
747 
748   friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
749                          const Shape& shape, const string& name,
750                          const std::vector<bool>& replicated_at_leaf_buffers);
751   friend XlaOp ConstantLiteral(XlaBuilder* builder,
752                                const LiteralSlice& literal);
753 
754   friend XlaOp Broadcast(XlaOp operand,
755                          absl::Span<const int64> broadcast_sizes);
756 
757   friend XlaOp BroadcastInDim(
758       XlaOp operand, const absl::Span<const int64> out_dim_size,
759       const absl::Span<const int64> broadcast_dimensions);
760 
761   friend XlaOp Copy(XlaOp operand);
762 
763   friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
764                    const PaddingConfig& padding_config);
765 
766   friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
767                        absl::Span<const int64> new_sizes);
768 
769   friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
770 
771   friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
772                                             absl::Span<const int64> new_sizes,
773                                             int64 inferred_dimension);
774 
775   friend XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
776 
777   friend XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
778                      absl::Span<const int64> limit_indices,
779                      absl::Span<const int64> strides);
780 
781   friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
782                           int64 stride, int64 dimno);
783 
784   friend XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
785                             absl::Span<const int64> slice_sizes);
786   friend XlaOp DynamicSlice(XlaOp operand,
787                             absl::Span<const XlaOp> start_indices,
788                             absl::Span<const int64> slice_sizes);
789 
790   friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
791                                   XlaOp start_indices);
792   friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
793                                   absl::Span<const XlaOp> start_indices);
794 
795   friend XlaOp ConcatInDim(XlaBuilder* builder,
796                            absl::Span<const XlaOp> operands, int64 dimension);
797 
798   friend void Trace(const string& tag, XlaOp operand);
799 
800   friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
801   friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
802   friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
803   friend XlaOp Eq(XlaOp lhs, XlaOp rhs,
804                   absl::Span<const int64> broadcast_dimensions);
805   friend XlaOp Ne(XlaOp lhs, XlaOp rhs,
806                   absl::Span<const int64> broadcast_dimensions);
807   friend XlaOp Ge(XlaOp lhs, XlaOp rhs,
808                   absl::Span<const int64> broadcast_dimensions);
809   friend XlaOp Gt(XlaOp lhs, XlaOp rhs,
810                   absl::Span<const int64> broadcast_dimensions);
811   friend XlaOp Lt(XlaOp lhs, XlaOp rhs,
812                   absl::Span<const int64> broadcast_dimensions);
813   friend XlaOp Le(XlaOp lhs, XlaOp rhs,
814                   absl::Span<const int64> broadcast_dimensions);
815   friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
816                        absl::Span<const int64> broadcast_dimensions,
817                        ComparisonDirection direction);
818   friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
819                    const PrecisionConfig* precision_config);
820   friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
821                           const DotDimensionNumbers& dimension_number,
822                           const PrecisionConfig* precision_config);
823   friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
824                     absl::Span<const int64> window_strides, Padding padding,
825                     int64 feature_group_count, int64 batch_group_count,
826                     const PrecisionConfig* precision_config);
827   friend XlaOp ConvWithGeneralPadding(
828       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
829       absl::Span<const std::pair<int64, int64>> padding,
830       int64 feature_group_count, int64 batch_group_count,
831       const PrecisionConfig* precision_config);
832   friend XlaOp ConvWithGeneralDimensions(
833       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
834       Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
835       int64 feature_group_count, int64 batch_group_count,
836       const PrecisionConfig* precision_config);
837   friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
838                            absl::Span<const int64> window_strides,
839                            absl::Span<const std::pair<int64, int64>> padding,
840                            const ConvolutionDimensionNumbers& dimension_numbers,
841                            int64 feature_group_count, int64 batch_group_count,
842                            const PrecisionConfig* precision_config);
843   friend XlaOp ConvGeneralDilated(
844       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
845       absl::Span<const std::pair<int64, int64>> padding,
846       absl::Span<const int64> lhs_dilation,
847       absl::Span<const int64> rhs_dilation,
848       const ConvolutionDimensionNumbers& dimension_numbers,
849       int64 feature_group_count, int64 batch_group_count,
850       const PrecisionConfig* precision_config);
851   friend XlaOp Fft(XlaOp operand, FftType fft_type,
852                    absl::Span<const int64> fft_length);
853   friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
854                                bool unit_diagonal,
855                                TriangularSolveOptions::Transpose transpose_a);
856   friend XlaOp Cholesky(XlaOp a, bool lower);
857   friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
858                       const string& config);
859   friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
860                       const string& outfeed_config);
861   friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
862                     absl::Span<const XlaOp> operands);
863   friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
864                           absl::Span<const XlaOp> operands, const Shape& shape,
865                           const string& opaque);
866   friend XlaOp CustomCallWithLayout(
867       XlaBuilder* builder, const string& call_target_name,
868       absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
869       absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
870   friend XlaOp Complex(XlaOp real, XlaOp imag,
871                        absl::Span<const int64> broadcast_dimensions);
872   friend XlaOp Conj(XlaOp operand);
873   friend XlaOp Add(XlaOp lhs, XlaOp rhs,
874                    absl::Span<const int64> broadcast_dimensions);
875   friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
876                    absl::Span<const int64> broadcast_dimensions);
877   friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
878                    absl::Span<const int64> broadcast_dimensions);
879   friend XlaOp Div(XlaOp lhs, XlaOp rhs,
880                    absl::Span<const int64> broadcast_dimensions);
881   friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
882                    absl::Span<const int64> broadcast_dimensions);
883   friend XlaOp Max(XlaOp lhs, XlaOp rhs,
884                    absl::Span<const int64> broadcast_dimensions);
885   friend XlaOp Min(XlaOp lhs, XlaOp rhs,
886                    absl::Span<const int64> broadcast_dimensions);
887   friend XlaOp And(XlaOp lhs, XlaOp rhs,
888                    absl::Span<const int64> broadcast_dimensions);
889   friend XlaOp Or(XlaOp lhs, XlaOp rhs,
890                   absl::Span<const int64> broadcast_dimensions);
891   friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
892                    absl::Span<const int64> broadcast_dimensions);
893   friend XlaOp Not(XlaOp operand);
894   friend XlaOp PopulationCount(XlaOp operand);
895   friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
896                          absl::Span<const int64> broadcast_dimensions);
897   friend XlaOp ShiftRightArithmetic(
898       XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
899   friend XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
900                                  absl::Span<const int64> broadcast_dimensions);
901   friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
902                       const XlaComputation& computation,
903                       absl::Span<const int64> dimensions_to_reduce);
904   friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
905                       absl::Span<const XlaOp> init_values,
906                       const XlaComputation& computation,
907                       absl::Span<const int64> dimensions_to_reduce);
908   friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
909                          const XlaComputation& computation);
910   friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
911                             const XlaComputation& computation,
912                             absl::Span<const int64> window_dimensions,
913                             absl::Span<const int64> window_strides,
914                             Padding padding);
915   friend XlaOp ReduceWindowWithGeneralPadding(
916       XlaOp operand, XlaOp init_value, const XlaComputation& computation,
917       absl::Span<const int64> window_dimensions,
918       absl::Span<const int64> window_strides,
919       absl::Span<const int64> base_dilations,
920       absl::Span<const int64> window_dilations,
921       absl::Span<const std::pair<int64, int64>> padding);
922   friend XlaOp CrossReplicaSum(XlaOp operand,
923                                absl::Span<const ReplicaGroup> replica_groups);
924   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
925                          absl::Span<const ReplicaGroup> replica_groups,
926                          const absl::optional<ChannelHandle>& channel_id,
927                          const absl::optional<Shape>& shape_with_layout);
928   friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
929                         int64 concat_dimension, int64 split_count,
930                         const std::vector<ReplicaGroup>& replica_groups);
931   friend XlaOp CollectivePermute(
932       XlaOp operand,
933       const std::vector<std::pair<int64, int64>>& source_target_pairs);
934   friend XlaOp ReplicaId(XlaBuilder* builder);
935   friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
936                                 absl::Span<const int64> window_dimensions,
937                                 absl::Span<const int64> window_strides,
938                                 Padding padding, XlaOp source, XlaOp init_value,
939                                 const XlaComputation& scatter);
940   friend XlaOp SelectAndScatterWithGeneralPadding(
941       XlaOp operand, const XlaComputation& select,
942       absl::Span<const int64> window_dimensions,
943       absl::Span<const int64> window_strides,
944       absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
945       XlaOp init_value, const XlaComputation& scatter);
946   friend XlaOp Abs(XlaOp operand);
947   friend XlaOp Atan2(XlaOp y, XlaOp x,
948                      absl::Span<const int64> broadcast_dimensions);
949   friend XlaOp Exp(XlaOp operand);
950   friend XlaOp Expm1(XlaOp operand);
951   friend XlaOp Floor(XlaOp operand);
952   friend XlaOp Ceil(XlaOp operand);
953   friend XlaOp Round(XlaOp operand);
954   friend XlaOp Log(XlaOp operand);
955   friend XlaOp Log1p(XlaOp operand);
956   friend XlaOp Sign(XlaOp operand);
957   friend XlaOp Clz(XlaOp operand);
958   friend XlaOp Cos(XlaOp operand);
959   friend XlaOp Sin(XlaOp operand);
960   friend XlaOp Tanh(XlaOp operand);
961   friend XlaOp Real(XlaOp operand);
962   friend XlaOp Imag(XlaOp operand);
963   friend XlaOp Sqrt(XlaOp operand);
964   friend XlaOp Rsqrt(XlaOp operand);
965   friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
966                    absl::Span<const int64> broadcast_dimensions);
967   friend XlaOp IsFinite(XlaOp operand);
968   friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
969                     int64 iota_dimension);
970   friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
971   friend XlaOp ConvertElementType(XlaOp operand,
972                                   PrimitiveType new_element_type);
973   friend XlaOp BitcastConvertType(XlaOp operand,
974                                   PrimitiveType new_element_type);
975   friend XlaOp Neg(XlaOp operand);
976   friend XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
977   friend XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
978   friend XlaOp Sort(absl::Span<const XlaOp> operands,
979                     const XlaComputation& comparator, int64 dimension,
980                     bool is_stable);
981   friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
982   friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
983                    const XlaComputation& computation,
984                    absl::Span<const int64> dimensions,
985                    absl::Span<const XlaOp> static_operands);
986   friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
987   friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
988   friend XlaOp While(const XlaComputation& condition,
989                      const XlaComputation& body, XlaOp init);
990   friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
991                            const XlaComputation& true_computation,
992                            XlaOp false_operand,
993                            const XlaComputation& false_computation);
994   friend XlaOp Conditional(
995       XlaOp branch_index,
996       absl::Span<const XlaComputation* const> branch_computations,
997       absl::Span<const XlaOp> branch_operands);
998   friend XlaOp ConditionalImpl(
999       XlaOp branch_index,
1000       absl::Span<const XlaComputation* const> branch_computations,
1001       absl::Span<const XlaOp> branch_operands);
1002   friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
1003                                const int mantissa_bits);
1004   friend XlaOp Gather(XlaOp input, XlaOp start_indices,
1005                       const GatherDimensionNumbers& dimension_numbers,
1006                       absl::Span<const int64> slice_sizes,
1007                       bool indices_are_sorted);
1008   friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1009                        const XlaComputation& update_computation,
1010                        const ScatterDimensionNumbers& dimension_numbers,
1011                        bool indices_are_sorted, bool unique_indices);
1012   friend void Send(XlaOp operand, const ChannelHandle& handle);
1013   friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1014                     const ChannelHandle& handle);
1015   friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
1016                                  float epsilon, int64 feature_index);
1017   friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
1018                                   XlaOp mean, XlaOp variance, float epsilon,
1019                                   int64 feature_index);
1020   friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
1021                              XlaOp batch_var, XlaOp grad_output, float epsilon,
1022                              int64 feature_index);
1023   friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
1024                              const ChannelHandle& handle);
1025   friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
1026                              const ChannelHandle& handle);
1027   friend XlaOp SendToHost(XlaOp operand, XlaOp token,
1028                           const Shape& shape_with_layout,
1029                           const ChannelHandle& handle);
1030   friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
1031                             const ChannelHandle& handle);
1032   friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1033                                const string& config);
1034   friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1035                                 const Shape& shape_with_layout,
1036                                 const string& outfeed_config);
1037   friend XlaOp CreateToken(XlaBuilder* builder);
1038   friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1039 
1040   friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
1041   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
1042 
1043  private:
1044   XlaOp ConditionalImpl(
1045       XlaOp branch_index,
1046       absl::Span<const XlaComputation* const> branch_computations,
1047       absl::Span<const XlaOp> branch_operands);
1048 };
1049 
1050 // RAII-style object: sets the current sharding assignment in builder on
1051 // construction, and sets back to the previous assignment on destruction.
1052 class XlaScopedShardingAssignment {
1053  public:
XlaScopedShardingAssignment(xla::XlaBuilder * builder,absl::optional<OpSharding> sharding)1054   XlaScopedShardingAssignment(xla::XlaBuilder* builder,
1055                               absl::optional<OpSharding> sharding)
1056       : builder_(builder), prev_sharding_(builder->sharding()) {
1057     SetSharding(sharding);
1058   }
1059 
1060   XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
1061   XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
1062       delete;
1063 
~XlaScopedShardingAssignment()1064   ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
1065 
1066  private:
SetSharding(const absl::optional<OpSharding> & sharding)1067   void SetSharding(const absl::optional<OpSharding>& sharding) {
1068     if (sharding.has_value()) {
1069       builder_->SetSharding(sharding.value());
1070     } else {
1071       builder_->ClearSharding();
1072     }
1073   }
1074 
1075   xla::XlaBuilder* const builder_;
1076   absl::optional<OpSharding> prev_sharding_;
1077 };
1078 
1079 // RAII-style object: save the current builder's frontend attributes, and merge
1080 // them with the new ones on construction.
1081 // Restore the original attributes on destruction.
1082 class XlaScopedFrontendAttributesAssignment {
1083  public:
XlaScopedFrontendAttributesAssignment(xla::XlaBuilder * builder,FrontendAttributes attributes)1084   XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
1085                                         FrontendAttributes attributes)
1086       : builder_(builder) {
1087     saved_ = builder_->SwapFrontendAttributes(attributes);
1088   }
1089 
~XlaScopedFrontendAttributesAssignment()1090   ~XlaScopedFrontendAttributesAssignment() {
1091     builder_->SetFrontendAttributes(saved_);
1092   }
1093 
1094  private:
1095   xla::XlaBuilder* const builder_;
1096   FrontendAttributes saved_;
1097 
1098   TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
1099 };
1100 // Free functions for building XlaOps. The intention is that these will
1101 // become the public API for building XlaOps rather than calling methods on
1102 // XlaBuilder directly.
1103 //
1104 
1105 // Enqueues a "retrieve parameter value" instruction for a parameter that was
1106 // passed to the computation.
1107 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
1108                 const string& name);
1109 
1110 // Same as above, but with leaf buffer replication annotation.
1111 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
1112                 const string& name,
1113                 const std::vector<bool>& replicated_at_leaf_buffers);
1114 
1115 // Enqueues a constant with the value of the given literal onto the
1116 // computation.
1117 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
1118 
1119 // Enqueues a constant onto the computation. Methods are templated on the
1120 // native host type (NativeT) which corresponds to a specific XLA
1121 // PrimitiveType as given in the following table:
1122 //
1123 //  Native Type   PrimitiveType
1124 // -----------------------------
1125 //   bool           PRED
1126 //   int32          S32
1127 //   int64          S64
1128 //   uint32         U32
1129 //   uint64         U64
1130 //   float          F32
1131 //   double         F64
1132 //
1133 // Note: not all primitive types defined in xla_data.proto have a
1134 // corresponding native type yet.
1135 template <typename NativeT>
1136 XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
1137 template <typename NativeT>
1138 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
1139 XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
1140 template <typename NativeT>
1141 XlaOp ConstantR2(XlaBuilder* builder,
1142                  std::initializer_list<std::initializer_list<NativeT>> values);
1143 template <typename NativeT>
1144 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
1145                                   const Array<NativeT>& values,
1146                                   const Layout& layout);
1147 template <typename NativeT>
1148 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
1149 template <typename NativeT>
1150 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
1151                                       const Array2D<NativeT>& values,
1152                                       const Layout& layout);
1153 template <typename NativeT>
1154 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
1155                             const Array2D<NativeT>& values);
1156 template <typename NativeT>
1157 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
1158                                       const Array3D<NativeT>& values,
1159                                       const Layout& layout);
1160 template <typename NativeT>
1161 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
1162                             const Array3D<NativeT>& values);
1163 template <typename NativeT>
1164 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
1165                                       const Array4D<NativeT>& values,
1166                                       const Layout& layout);
1167 template <typename NativeT>
1168 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
1169                             const Array4D<NativeT>& values);
1170 
// Enqueues a rank-one constant (a vector) onto the computation. The vector
// has size 'length' and every element has the value 'value'.
1174 template <typename NativeT>
1175 XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
1176 
1177 // Adds dimensions to an array by duplicating the data in the array.
1178 //
1179 // The new dimensions are inserted on the left, i.e. if
1180 // broadcast_sizes has values {a0, ..., aN} and the operand shape
1181 // has dimensions {b0, ..., bM} then the shape of the output has
1182 // dimensions {a0, ..., aN, b0, ..., bM}.
1183 //
1184 // The new dimensions index into copies of the operand, i.e.
1185 //
1186 //   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
1187 XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);
1188 
1189 // This op broadcasts the `operand` to an output with the given `shape`.
1190 // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the
1191 // i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
1192 // dimension of the output. This also requires that the i'th input dimension is
1193 // either 1 or is the same as the output dimension it's broadcasting into.
1194 //
1195 // For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
1196 // output shape is s32[2,2]:
1197 // - Specifying {1} as broadcast_dimension will generate output
1198 //   {{1, 2},
1199 //    {1, 2}}
1200 // - On the other hand, specifying {0} as broadcast_dimension
1201 //   will generate output
1202 //   {{1 , 1},
1203 //    {2 , 2}}
1204 XlaOp BroadcastInDim(XlaOp operand, const absl::Span<const int64> out_dim_size,
1205                      const absl::Span<const int64> broadcast_dimensions);
1206 
1207 // Copies the input operand to the output. This operation is for internal
1208 // purpose and is only used by the compiler for optimization purposes or to
1209 // ensure correctness. The XLA client should never have to generate this
1210 // instruction.
1211 //
1212 // Copy has two potential use cases:
1213 //
1214 // * Create a copy of the operand with a new layout.
1215 //
1216 // * Create a copy of the operand in a separately allocated buffer. This is
1217 //   necessary for some backends if the operand is a parameter or constant and
1218 //   the operand is returned within a tuple. In this case, the lifetime of the
1219 //   operand buffer must be the same as the lifetime of the output result.
1220 //   However, the lifetimes of parameters and constants are managed separately
1221 //   from the lifetime of the output result. Creating a separate copy of the
1222 //   parameter or constant buffer resolves this issue.
1223 XlaOp Copy(XlaOp operand);
1224 
1225 // Enqueues a pad operation onto the computation that pads the given value on
1226 // the edges as well as between the elements of the input. padding_config
1227 // specifies the padding amount for each dimension.
1228 XlaOp Pad(XlaOp operand, XlaOp padding_value,
1229           const PaddingConfig& padding_config);
1230 
1231 // Enqueues an operation onto the computation that flattens the operand based
1232 // on the dimension order (major/slowest-varying to minor/fastest-varying)
1233 // given, followed by reshaping it into the shape with the given dimension
1234 // sizes (also major to minor). Conceptually, this is a limited form of
1235 // "shape casting".
1236 XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
1237               absl::Span<const int64> new_sizes);
1238 
1239 // Enqueues an operation onto the computation that collapses the operand, from
1240 // first to last dimension (C order), then reshapes it to the given dimension
1241 // sizes. Conceptually, this is a limited form of "shape casting".
1242 XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
1243 
1244 // `inferred_dimension` represents the output dimension that's inferred by
1245 // upper-level framework by dividing the input element count by the known
1246 // output element count. While an inferred_dimension can be static, if there
1247 // is a dynamic dimension in the output, it must be the inferred dimension.
1248 XlaOp ReshapeWithInferredDimension(XlaOp operand,
1249                                    absl::Span<const int64> new_sizes,
1250                                    int64 inferred_dimension);
1251 
1252 // Wrapper for Reshape.
1253 // Enqueues an operation to collapse the provided dimensions; e.g. an
1254 // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
1255 // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
1256 // be a consecutive, in-order subsequence of the operand dimensions.
1257 //
1258 // Note that collapsing a single dimension does nothing:
1259 //
1260 //    {256} collapsing {0} => {256}
1261 //    {1} collapsing {0} => {1}
1262 //
1263 // Collapsing multiple dimensions produces a single result dimension:
1264 //
1265 //    {256, 2} collapsing {0,1} => {512}
1266 //    {256, 2, 3} collapsing {0,1} => {512, 3}
1267 //
1268 // This could potentially cause data to be moved -- it provides a more
1269 // structured form of reshaping than an arbitrary Reshape operation.
1270 XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
1271 
1272 // Enqueues a slice operation onto the computation that slices the operand
1273 // from the start indices to the limit indices; e.g.
1274 //
1275 //        x
1276 //   [ 0 1 2 3 ]
1277 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
1278 //   [ 8 9 a b ]
1279 //
1280 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
1281 // range notation.
1282 // The strides parameter determines the stride over the slice
1283 XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
1284             absl::Span<const int64> limit_indices,
1285             absl::Span<const int64> strides);
1286 
1287 // Enqueues a slice operation in a given dimension, taking all other
1288 // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
1289 // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
1290 // for:
1291 //
1292 //  array[:, 2:4:1, :]
1293 XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
1294                  int64 stride, int64 dimno);
1295 
1296 // Enqueues a slice operation onto the computation that slices the 'operand'
1297 // from dynamic start indices which are passed in 'start_indices'.
1298 // The size of the slice in each dimension is passed in 'slice_sizes',
1299 // which specify the end point of exclusive slice intervals in each
1300 // dimension [start, start + size).
1301 // The shape of each element of 'start_indices' must be scalar, with the span
1302 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1303 // have the same shape.
1304 // Slice index calculations are computed modulo input dimension sizes to
1305 // prevent dynamic start indices from generating out-of-bound array accesses.
1306 XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
1307                    absl::Span<const int64> slice_sizes);
1308 
1309 ABSL_DEPRECATED("Use span-of-indices form instead")
1310 XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
1311                    absl::Span<const int64> slice_sizes);
1312 
1313 // Enqueues a dynamic update slice operation onto the computation, which
1314 // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
1315 // The shape of 'update' determines the shape of the slice of 'operand'
1316 // which is updated.
1317 // The indices specified in 'start_indices' specify the offset of the slice
1318 // of 'operand' which is updated.
1319 //
1320 //               update = {10, 11} // calculated at runtime.
1321 //   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
//   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
1323 //   [7 8 9]                                                  [7 8  9 ]
1324 //
1325 // The shape of each element of 'start_indices' must be scalar, with the span
1326 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1327 // have the same shape.
1328 // Slice index calculations are computed modulo update dimension sizes to
1329 // prevent dynamic start indices from generating out-of-bound array accesses.
1330 XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
1331                          absl::Span<const XlaOp> start_indices);
1332 
1333 ABSL_DEPRECATED("Use span-of-indices form instead")
1334 XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices);
1335 
1336 // Enqueues a concatenate instruction onto the computation. 'operands' must
1337 // have >= 1 entry.
1338 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1339                   int64 dimension);
1340 
1341 // Enqueue a tracing operation onto the computation; the computation will emit
1342 // a logging message with the operand.
1343 void Trace(const string& tag, XlaOp operand);
1344 
1345 // Enqueues a conditional-move-like select operation onto the computation;
1346 // predicated on pred, selects between on_true and on_false.
1347 XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
1348 
1349 // Enqueues a tuple-creation instruction onto the computation.
1350 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
1351 
1352 // Enqueues a tuple-element-get instruction onto the computation.
1353 XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
1354 
1355 // Enqueues an equal-to comparison instruction onto the computation.
1356 XlaOp Eq(XlaOp lhs, XlaOp rhs,
1357          absl::Span<const int64> broadcast_dimensions = {});
1358 
1359 // Enqueues a not-equal comparison instruction onto the computation.
1360 XlaOp Ne(XlaOp lhs, XlaOp rhs,
1361          absl::Span<const int64> broadcast_dimensions = {});
1362 
1363 // Enqueues a greater-or-equal comparison instruction onto the computation.
1364 XlaOp Ge(XlaOp lhs, XlaOp rhs,
1365          absl::Span<const int64> broadcast_dimensions = {});
1366 
1367 // Enqueues a greater-than comparison instruction onto the computation.
1368 XlaOp Gt(XlaOp lhs, XlaOp rhs,
1369          absl::Span<const int64> broadcast_dimensions = {});
1370 
1371 // Enqueues a less-than comparison instruction onto the computation.
1372 XlaOp Lt(XlaOp lhs, XlaOp rhs,
1373          absl::Span<const int64> broadcast_dimensions = {});
1374 
1375 // Enqueues a less-or-equal comparison instruction onto the computation.
1376 XlaOp Le(XlaOp lhs, XlaOp rhs,
1377          absl::Span<const int64> broadcast_dimensions = {});
1378 
1379 // Enqueues a comparison instruction onto the computation.
1380 XlaOp Compare(XlaOp lhs, XlaOp rhs,
1381               absl::Span<const int64> broadcast_dimensions,
1382               ComparisonDirection direction);
1383 
1384 // Enqueues a dot instruction onto the computation.
1385 XlaOp Dot(XlaOp lhs, XlaOp rhs,
1386           const PrecisionConfig* precision_config = nullptr);
1387 
1388 // Enqueues a general dot instruction onto the computation.
1389 XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
1390                  const DotDimensionNumbers& dimension_numbers,
1391                  const PrecisionConfig* precision_config = nullptr);
1392 
1393 // Enqueues a convolution instruction onto the computation, which uses the
1394 // default convolution dimension numbers.
1395 XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1396            Padding padding, int64 feature_group_count = 1,
1397            int64 batch_group_count = 1,
1398            const PrecisionConfig* precision_config = nullptr);
1399 
1400 // Enqueues a convolution instruction onto the computation, with the caller
1401 // provided padding configuration in the format returned by MakePadding().
1402 XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs,
1403                              absl::Span<const int64> window_strides,
1404                              absl::Span<const std::pair<int64, int64>> padding,
1405                              int64 feature_group_count = 1,
1406                              int64 batch_group_count = 1,
1407                              const PrecisionConfig* precision_config = nullptr);
1408 
1409 // Enqueues a convolution instruction onto the computation, with the caller
1410 // provided dimension numbers configuration.
1411 XlaOp ConvWithGeneralDimensions(
1412     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1413     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1414     int64 feature_group_count = 1, int64 batch_group_count = 1,
1415     const PrecisionConfig* precision_config = nullptr);
1416 
1417 // Enqueues a convolution instruction onto the computation, with the caller
1418 // provided padding configuration as well as the dimension numbers.
1419 XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
1420                   absl::Span<const std::pair<int64, int64>> padding,
1421                   const ConvolutionDimensionNumbers& dimension_numbers,
1422                   int64 feature_group_count = 1, int64 batch_group_count = 1,
1423                   const PrecisionConfig* precision_config = nullptr);
1424 
1425 // Enqueues a convolution instruction onto the computation, with the caller
1426 // provided padding configuration, dilation factors and dimension numbers.
1427 XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
1428                          absl::Span<const int64> window_strides,
1429                          absl::Span<const std::pair<int64, int64>> padding,
1430                          absl::Span<const int64> lhs_dilation,
1431                          absl::Span<const int64> rhs_dilation,
1432                          const ConvolutionDimensionNumbers& dimension_numbers,
1433                          int64 feature_group_count = 1,
1434                          int64 batch_group_count = 1,
1435                          const PrecisionConfig* precision_config = nullptr);
1436 
1437 // Enqueues an FFT instruction onto the computation, of the given type and
1438 // with the given FFT length.
1439 XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);
1440 
1441 // Solves systems of linear equations with lower or upper triangular coefficient
1442 // matrices by forward- or back-substitution. Broadcasting along leading
1443 // dimensions, this routine solves for x in one of the matrix systems
1444 //   `op(a) * x = b`,  or `x * op(a) = b`,
1445 // for the variable `x` given `a` and `b`, where `op(a)` is either
1446 //   `op(a) = a`,  or `op(a) = transpose(a)`,  or `op(a) = conj(transpose(a))`.
1447 //
1448 // * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
1449 //   square matrices. If `lower` is true (false), then the strictly upper
1450 //   (lower) triangular part of each innermost matrix in `a` is assumed to be
1451 //   zero and is not accessed.
1452 // * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
1453 //   tensor of shape `[..., K, M]`.
1454 // * `left_side` is a boolean, indicating whether to solve a system of the form
1455 //   op(a) * x = b (true) or x * op(a) = b (false).
1456 // * `lower` is a boolean, indicating whether the argument `a` is
1457 //   lower-triangular (true) or upper-triangular (false).
1458 // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
1459 //   1 and not accessed.
1460 // * `transpose_a` indicates which function `op` we use to transform the tensor
1461 //   `a`: the identity function, transpose(a), or conjugate(transpose(a))
1462 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
1463                       bool unit_diagonal,
1464                       TriangularSolveOptions::Transpose transpose_a);
1465 
1466 // Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
1467 // positive definite matrices.
1468 // `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
1469 // two minor dimensions equal.
1470 // If `lower` is true, the data from the lower triangle is used; if false, the
1471 // upper triangle is used. The input data in the other triangle of the input
1472 // does not affect the output. Returns the output in the same lower/upper
1473 // triangle. The data returned in the other output triangle is arbitrary and
1474 // implementation-defined.
1475 //
1476 // If `a` is not Hermitian positive definite, returns an array full of NaNs.
1477 XlaOp Cholesky(XlaOp a, bool lower);
1478 
1479 // Enqueues an infeed instruction onto the computation, which writes data of
1480 // the given shape to the infeed buffer of the device.
1481 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
1482              const string& config = "");
1483 
1484 // Variant of Infeed which takes a token-shaped operand and produces a
1485 // two-element tuple containing the data value and a token-shaped value.
1486 // Tokens are used for ordering side-effecting operations.
1487 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
1488 XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1489                       const string& config = "");
1490 
1491 // Enqueues an outfeed instruction onto the computation. This instruction
1492 // generates outgoing data transfers for the given data.
1493 //
1494 // shape_with_layout communicates the laid out shape that we want to outfeed
1495 // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
1496 // will occur.
1497 void Outfeed(XlaOp operand, const Shape& shape_with_layout,
1498              const string& outfeed_config);
1499 
1500 // Variant of Outfeed which takes a token-shaped operand and produces a
1501 // token-shaped value. Tokens are used for ordering side-effecting operations.
1502 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
1503 XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1504                        const Shape& shape_with_layout,
1505                        const string& outfeed_config);
1506 
1507 // Enqueues a call instruction onto the computation.
1508 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
1509            absl::Span<const XlaOp> operands);
1510 
1511 // Enqueues a custom call instruction onto the computation. A custom call
1512 // invokes code external to XLA. The |operands| are passed to the external code,
1513 // and the external code is expected to produce a result of the given
1514 // |shape|. The exact mechanism is backend-specific. For example, in the CPU
1515 // backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|.  |call_target_name| and |opaque| can be arbitrary
// strings,
1517 // but |call_target_name| should be short as it may be used in labels. |opaque|
1518 // can encode arbitrarily large amounts of information.
1519 XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
1520                  absl::Span<const XlaOp> operands, const Shape& shape,
1521                  const string& opaque = "");
1522 
1523 // Overload which constructs a custom call with fixed layouts. The operands will
1524 // have the layouts specified by |operand_shapes_with_layout| when provided to
1525 // external code, and the external code is expected to produce a result with the
1526 // layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
1527 // and |operand_shapes_with_layout| must have layouts.
1528 XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
1529                            absl::Span<const XlaOp> operands,
1530                            const Shape& shape_with_layout,
1531                            absl::Span<const Shape> operand_shapes_with_layout,
1532                            const string& opaque = "");
1533 
1534 // The following methods enqueue element-wise binary arithmetic operations
1535 // onto the computation. The shapes of the operands have to match unless one
1536 // of the operands is a scalar, or an explicit broadcast dimension is given
1537 // (see g3doc for more details).
1538 
1539 // Enqueues a complex compose instruction onto the computation.
1540 XlaOp Complex(XlaOp real, XlaOp imag,
1541               absl::Span<const int64> broadcast_dimensions = {});
1542 
1543 // Enqueues a complex conjugate instruction onto the computation.
1544 XlaOp Conj(XlaOp operand);
1545 
1546 // Enqueues an add instruction onto the computation.
1547 XlaOp Add(XlaOp lhs, XlaOp rhs,
1548           absl::Span<const int64> broadcast_dimensions = {});
1549 
1550 // Enqueues a subtract instruction onto the computation.
1551 XlaOp Sub(XlaOp lhs, XlaOp rhs,
1552           absl::Span<const int64> broadcast_dimensions = {});
1553 
1554 // Enqueues a multiply instruction onto the computation.
1555 XlaOp Mul(XlaOp lhs, XlaOp rhs,
1556           absl::Span<const int64> broadcast_dimensions = {});
1557 
1558 // Enqueues a divide instruction onto the computation.
1559 XlaOp Div(XlaOp lhs, XlaOp rhs,
1560           absl::Span<const int64> broadcast_dimensions = {});
1561 
1562 // Enqueues a remainder instruction onto the computation.
1563 XlaOp Rem(XlaOp lhs, XlaOp rhs,
1564           absl::Span<const int64> broadcast_dimensions = {});
1565 
1566 // Enqueues a max instruction onto the computation.
1567 XlaOp Max(XlaOp lhs, XlaOp rhs,
1568           absl::Span<const int64> broadcast_dimensions = {});
1569 
1570 // Enqueues a min instruction onto the computation.
1571 XlaOp Min(XlaOp lhs, XlaOp rhs,
1572           absl::Span<const int64> broadcast_dimensions = {});
1573 
1574 // Element-wise logical operators
1575 XlaOp And(XlaOp lhs, XlaOp rhs,
1576           absl::Span<const int64> broadcast_dimensions = {});
1577 
1578 // Overload to call And with 3 or more operands.  We need the following somewhat
1579 // convoluted overload set to disambiguate with the overload that takes the
1580 // `broadcast_dimensions` optional param.
And(XlaOp op1,XlaOp op2,XlaOp op3)1581 inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
1582   return And(op1, And(op2, op3));
1583 }
1584 template <typename... XlaOpTs>
And(XlaOp op1,XlaOp op2,XlaOp op3,const XlaOpTs &...operands)1585 XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
1586   return And(op1, And(op2, And(op3, operands...)));
1587 }
1588 
1589 XlaOp Or(XlaOp lhs, XlaOp rhs,
1590          absl::Span<const int64> broadcast_dimensions = {});
1591 
1592 // Overload to call Or with 3 or more operands.  As with `And`, we need the
1593 // following complicated overload set to handle the default arg in the `Or`
1594 // overload above.
Or(XlaOp op1,XlaOp op2,XlaOp op3)1595 inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
1596   return Or(op1, Or(op2, op3));
1597 }
1598 template <typename... XlaOpTs>
Or(XlaOp op1,XlaOp op2,XlaOp op3,const XlaOpTs &...operands)1599 XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
1600   return Or(op1, Or(op2, Or(op3, operands...)));
1601 }
1602 
1603 XlaOp Xor(XlaOp lhs, XlaOp rhs,
1604           absl::Span<const int64> broadcast_dimensions = {});
1605 
1606 XlaOp Not(XlaOp operand);
1607 
1608 XlaOp PopulationCount(XlaOp operand);
1609 
1610 XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
1611                 absl::Span<const int64> broadcast_dimensions = {});
1612 XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
1613                            absl::Span<const int64> broadcast_dimensions = {});
1614 XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
1615                         absl::Span<const int64> broadcast_dimensions = {});
1616 
1617 // Reduces an array among the provided dimensions, given "computation" as a
1618 // reduction operator.
1619 XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1620              absl::Span<const int64> dimensions_to_reduce);
1621 
1622 // Reduces several arrays simultaneously among the provided dimensions, given
1623 // "computation" as a reduction operator.
1624 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1625              absl::Span<const XlaOp> init_values,
1626              const XlaComputation& computation,
1627              absl::Span<const int64> dimensions_to_reduce);
1628 
1629 // Convenience wrapper around the above that reduces all the dimensions in the
1630 // operand shape.
1631 XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
1632                 const XlaComputation& computation);
1633 
1634 // Enqueues a windowed reduce instruction onto the computation.
1635 XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
1636                    const XlaComputation& computation,
1637                    absl::Span<const int64> window_dimensions,
1638                    absl::Span<const int64> window_strides, Padding padding);
1639 
1640 // As ReduceWindow(), but the padding is given in the format
1641 // returned by MakePadding().
1642 XlaOp ReduceWindowWithGeneralPadding(
1643     XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1644     absl::Span<const int64> window_dimensions,
1645     absl::Span<const int64> window_strides,
1646     absl::Span<const int64> base_dilations,
1647     absl::Span<const int64> window_dilations,
1648     absl::Span<const std::pair<int64, int64>> padding);
1649 
1650 // Returns the sum of the operand value within each subgroup of replicas. All
1651 // replicas supply one input to the sum and all replicas receive the resulting
1652 // sum for each subgroup.
1653 XlaOp CrossReplicaSum(XlaOp operand,
1654                       absl::Span<const ReplicaGroup> replica_groups = {});
1655 
// Enqueues an operation that does an AllReduce of the operand across cores.
// Here AllReduce means doing a reduction on the input operand across cores and
// then
1658 // broadcasting the reduction result to those cores. The reduction function is
1659 // defined by `computation`, which should be a commutative computation on
1660 // scalars, e.g., add, min, or max. The way that AllReduce is applied is
1661 // configured by:
1662 //
1663 // - `replica_groups`: each ReplicaGroup contains a list of replica id. If
1664 // empty, all replicas belong to one group. Allreduce will be applied within
1665 // subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}}
1666 // means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
1667 //
1668 // - `channel_id`: for Allreduce nodes from different modules, if they have the
1669 // same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
1670 // applied cross modules.
1671 //
1672 // - `shape_with_layout`: forces the layout of the AllReduce to the given
1673 // layout. This is used to guarantee the same layout for a group of AllReduce
1674 // ops compiled separately.
1675 XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
1676                 absl::Span<const ReplicaGroup> replica_groups = {},
1677                 const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
1678                 const absl::optional<Shape>& shape_with_layout = absl::nullopt);
1679 
// Enqueues an operation that does an AllToAll of the operand across cores.
1681 XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
1682                int64 split_count,
1683                const std::vector<ReplicaGroup>& replica_groups = {});
1684 
// Enqueues a collective operation that sends and receives data across
// replicas.
1686 //
1687 // - `source_target_pair`: a list of (source_replica_id, target_replica_id)
1688 // pairs. For each pair, the operand is sent from source replica to target
1689 // replica. Note that, 1) any two pairs should not have the same target replica
1690 // id, and they should not have the same source replica id; 2) if a replica id
1691 // is not a target in any pair, then the output on that replica is a tensor
// consisting of 0(s) with the same shape as the input.
1693 XlaOp CollectivePermute(
1694     XlaOp operand,
1695     const std::vector<std::pair<int64, int64>>& source_target_pairs);
1696 
1697 // Enqueues an operation that returns the replica ID.
1698 XlaOp ReplicaId(XlaBuilder* builder);
1699 
1700 // Enqueues an operation that scatters the `source` array to the selected
1701 // indices of each window.
1702 XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
1703                        absl::Span<const int64> window_dimensions,
1704                        absl::Span<const int64> window_strides, Padding padding,
1705                        XlaOp source, XlaOp init_value,
1706                        const XlaComputation& scatter);
1707 
1708 // As SelectAndScatter(), but the padding is given in the format
1709 // returned by MakePadding().
1710 XlaOp SelectAndScatterWithGeneralPadding(
1711     XlaOp operand, const XlaComputation& select,
1712     absl::Span<const int64> window_dimensions,
1713     absl::Span<const int64> window_strides,
1714     absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
1715     XlaOp init_value, const XlaComputation& scatter);
1716 
1717 // Enqueues an abs instruction onto the computation.
1718 XlaOp Abs(XlaOp operand);
1719 
// Enqueues an atan2 instruction onto the computation.
1721 XlaOp Atan2(XlaOp y, XlaOp x,
1722             absl::Span<const int64> broadcast_dimensions = {});
1723 
1724 // Enqueues an exp instruction onto the computation.
1725 XlaOp Exp(XlaOp operand);
1726 
1727 // Enqueues an expm1 instruction onto the computation.
1728 XlaOp Expm1(XlaOp operand);
1729 
1730 // Enqueues a floor instruction onto the computation.
1731 XlaOp Floor(XlaOp operand);
1732 
1733 // Enqueues a ceil instruction onto the computation.
1734 XlaOp Ceil(XlaOp operand);
1735 
1736 // Enqueues a round instruction onto the computation, rounding to nearest even
1737 // with half-way cases rounding away from zero.
1738 XlaOp Round(XlaOp operand);
1739 
// Enqueues a log instruction (natural logarithm) onto the computation.
1741 XlaOp Log(XlaOp operand);
1742 
// Enqueues a log1p instruction (log(x+1)) onto the computation.
1744 XlaOp Log1p(XlaOp operand);
1745 
1746 // Enqueues a sign instruction onto the computation.
1747 XlaOp Sign(XlaOp operand);
1748 
1749 // Enqueues a count leading zeros instruction onto the computation.
1750 XlaOp Clz(XlaOp operand);
1751 
1752 // Enqueues a cosine instruction onto the computation.
1753 XlaOp Cos(XlaOp operand);
1754 
1755 // Enqueues a sine instruction onto the computation.
1756 XlaOp Sin(XlaOp operand);
1757 
1758 // Enqueues a tanh instruction onto the computation.
1759 XlaOp Tanh(XlaOp operand);
1760 
1761 // Enqueues a real-part instruction onto the computation.
1762 XlaOp Real(XlaOp operand);
1763 
1764 // Enqueues an imaginary-part instruction onto the computation.
1765 XlaOp Imag(XlaOp operand);
1766 
1767 // Enqueues a sqrt computation onto the computation.
1768 XlaOp Sqrt(XlaOp operand);
1769 
1770 // Enqueues a rsqrt computation onto the computation.
1771 XlaOp Rsqrt(XlaOp operand);
1772 
1773 // Enqueues a lhs^rhs computation onto the computation.
1774 XlaOp Pow(XlaOp lhs, XlaOp rhs,
1775           absl::Span<const int64> broadcast_dimensions = {});
1776 
1777 // Enqueues an operator that tests if the operand's values are finite, i.e., not
1778 // +/-Inf or NaN.  Returns an array of booleans with the same shape where
1779 // entries are true iff the corresponding entry was not infinite or NaN.
1780 //
1781 // Defined only for real-valued (i.e. not complex) floating-point types; raises
1782 // an error for other types.
1783 //
1784 // See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
1785 XlaOp IsFinite(XlaOp operand);
1786 
1787 // Enqueues an iota operation onto the computation.
1788 XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
1789 
1790 // Enqueues a rank-1 iota operation onto the computation.
1791 XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
1792 
1793 // Enqueues a convert instruction onto the computation that changes the
1794 // element type of the operand array to primitive_type.
1795 XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);
1796 
1797 // Enqueues a no-op instruction onto the computation that changes
1798 // the element type of the operand array to primitive_type. The
1799 // bit-widths of the source and destination element types must be
1800 // identical.
1801 XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
1802 
1803 // Enqueues a negate instruction onto the computation.
1804 XlaOp Neg(XlaOp operand);
1805 
1806 // Enqueues a transpose instruction onto the computation.
1807 XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
1808 
1809 // Enqueues a reverse instruction onto the computation. The order of the
1810 // elements in the given dimensions is reversed (i.e., the element at index i
1811 // is moved to index dimension_size - 1 - i).
1812 XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
1813 
1814 // Enqueues a sort instruction onto the computation, using 'comparator' for
1815 // comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
1816 // determines whether the stable sorting should be used.
1817 // If only one operand is provided:
1818 // * If the operand is a rank-1 tensor (an array), the result is a sorted array.
1819 //   The resulting sorting order has the property that for all index positions
1820 //   i, j with i < j, either
1821 //   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
1822 //   comparator(value[i], value[j]) = true.
1823 // * If the operand has higher rank, the operand is sorted along the provided
1824 //   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
1825 //   of 0 will independently sort every column, and a dimension value of 1 will
1826 //   independently sort each row. If no dimension number is provided, then the
1827 //   last dimension is chosen by default. For the dimension which is sorted, the
1828 //   same sorting order applies as in the rank-1 case.
1829 //
1830 // If more than one operand is provided:
1831 // * All operands must be tensors with the same dimensions. The element types of
1832 //   the tensors may be different.
1833 // * The result is a tuple that consists of the operands in sorted order (along
1834 //   the provided dimension, as above). The same permutation as implied by the
1835 //   comparison computation is applied to all operand tensors. When comparing
1836 //   two index positions, 'comparator' is called with 2 * n scalar parameters,
1837 //   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
1838 //   two index positions.
1839 // Default comparator computations can be found in lib/comparators.h
1840 XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
1841            int64 dimension = -1, bool is_stable = false);
1842 
1843 // Enqueues a clamp instruction onto the computation.
1844 XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
1845 
1846 // Enqueues a map instruction onto the computation.
1847 XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1848           const XlaComputation& computation, absl::Span<const int64> dimensions,
1849           absl::Span<const XlaOp> static_operands = {});
1850 
// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation. `shape` gives both the element type and dimensions of the
// produced random values.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

// Enqueues a while node onto the computation. `condition` must produce a
// scalar PRED; `body` is applied to `init` repeatedly while the condition
// holds.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            XlaOp init);

// Enqueues a conditional node onto the computation: evaluates
// `true_computation(true_operand)` when `predicate` is true, otherwise
// `false_computation(false_operand)`.
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                  const XlaComputation& true_computation, XlaOp false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. Out of range branch_index uses the N-1'th
// branch_computation as default.
XlaOp Conditional(XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);

// Enqueues a ReducePrecision node onto the computation, simulating a
// lower-precision float with the given exponent and mantissa bit widths.
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation. Setting `indices_are_sorted`
// is a promise from the caller that allows backends to optimize; it does not
// change semantics for already-sorted indices.
XlaOp Gather(XlaOp input, XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes,
             bool indices_are_sorted = false);

// Enqueues a Scatter node onto the computation. `update_computation` combines
// existing values with updates; `indices_are_sorted` and `unique_indices` are
// caller promises that enable backend optimizations.
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value.  Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value.  Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64 feature_index);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_offset: Gradient with respect to input `offset`
//   - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64 feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64 dimension);

// Returns the operand with `dimension`'s size set from `val`.
// NOTE(review): presumably this marks the dimension as dynamic with runtime
// size `val` — confirm against XlaBuilder::SetDimensionSize semantics.
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

// Implementation details below this point.
//

// Free function template implementations.
1990 template <typename NativeT>
ConstantR0(XlaBuilder * builder,NativeT value)1991 XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
1992   return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
1993 }
1994 
1995 template <typename NativeT>
ConstantR1(XlaBuilder * builder,absl::Span<const NativeT> values)1996 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
1997   BorrowingLiteral literal(
1998       reinterpret_cast<const char*>(values.begin()),
1999       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
2000                            {static_cast<int64>(values.size())}));
2001   return ConstantLiteral(builder, literal);
2002 }
2003 
2004 template <typename NativeT>
ConstantR1(XlaBuilder * builder,int64 length,NativeT value)2005 XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
2006   Literal literal(ShapeUtil::MakeShape(
2007       primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
2008   literal.PopulateWithValue(value);
2009   return ConstantLiteral(builder, literal);
2010 }
2011 
ConstantR1(XlaBuilder * builder,const tensorflow::core::Bitmap & values)2012 inline XlaOp ConstantR1(XlaBuilder* builder,
2013                         const tensorflow::core::Bitmap& values) {
2014   return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
2015 }
2016 
2017 template <typename NativeT>
ConstantR2(XlaBuilder * builder,std::initializer_list<std::initializer_list<NativeT>> values)2018 XlaOp ConstantR2(XlaBuilder* builder,
2019                  std::initializer_list<std::initializer_list<NativeT>> values) {
2020   return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
2021 }
2022 
2023 template <typename NativeT>
ConstantFromArrayWithLayout(XlaBuilder * builder,const Array<NativeT> & values,const Layout & layout)2024 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
2025                                   const Array<NativeT>& values,
2026                                   const Layout& layout) {
2027   return ConstantLiteral(
2028       builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
2029 }
2030 
2031 template <typename NativeT>
ConstantFromArray(XlaBuilder * builder,const Array<NativeT> & values)2032 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
2033   return ConstantLiteral(builder,
2034                          LiteralUtil::CreateFromArray<NativeT>(values));
2035 }
2036 
2037 template <typename NativeT>
ConstantR2FromArray2DWithLayout(XlaBuilder * builder,const Array2D<NativeT> & values,const Layout & layout)2038 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
2039                                       const Array2D<NativeT>& values,
2040                                       const Layout& layout) {
2041   return ConstantLiteral(
2042       builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
2043 }
2044 
2045 template <typename NativeT>
ConstantR2FromArray2D(XlaBuilder * builder,const Array2D<NativeT> & values)2046 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
2047                             const Array2D<NativeT>& values) {
2048   return ConstantLiteral(builder,
2049                          LiteralUtil::CreateR2FromArray2D<NativeT>(values));
2050 }
2051 
2052 template <typename NativeT>
ConstantR3FromArray3DWithLayout(XlaBuilder * builder,const Array3D<NativeT> & values,const Layout & layout)2053 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
2054                                       const Array3D<NativeT>& values,
2055                                       const Layout& layout) {
2056   return ConstantLiteral(
2057       builder,
2058       LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
2059 }
2060 
2061 template <typename NativeT>
ConstantR3FromArray3D(XlaBuilder * builder,const Array3D<NativeT> & values)2062 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
2063                             const Array3D<NativeT>& values) {
2064   return ConstantFromArray(builder, values);
2065 }
2066 
2067 template <typename NativeT>
ConstantR4FromArray4DWithLayout(XlaBuilder * builder,const Array4D<NativeT> & values,const Layout & layout)2068 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
2069                                       const Array4D<NativeT>& values,
2070                                       const Layout& layout) {
2071   return ConstantFromArrayWithLayout(builder, values, layout);
2072 }
2073 
2074 template <typename NativeT>
ConstantR4FromArray4D(XlaBuilder * builder,const Array4D<NativeT> & values)2075 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
2076                             const Array4D<NativeT>& values) {
2077   return ConstantFromArray(builder, values);
2078 }
2079 
2080 }  // namespace xla
2081 
2082 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
2083