/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;
// This represents an instruction that has been enqueued using the XlaBuilder.
// This is used to pass to subsequent computations that depend upon the
// instruction as an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar(). Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point
  // possibly deep in the callstack when we finally dereference `this`. The
  // precondition lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(XlaOp rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64 handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};

// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(XlaOp x);
XlaOp operator+(XlaOp x, XlaOp y);
XlaOp operator-(XlaOp x, XlaOp y);
XlaOp operator*(XlaOp x, XlaOp y);
XlaOp operator/(XlaOp x, XlaOp y);
XlaOp operator%(XlaOp x, XlaOp y);

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(XlaOp x);
XlaOp operator&(XlaOp x, XlaOp y);
XlaOp operator|(XlaOp x, XlaOp y);
XlaOp operator^(XlaOp x, XlaOp y);
XlaOp operator<<(XlaOp x, XlaOp y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise
// performs a right logical shift.
XlaOp operator>>(XlaOp x, XlaOp y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because
// the semantics might be surprising since their result types are usually
// 'bool'. Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.
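
// Example (an illustrative sketch, not additional API; assumes float
// parameters of matching shape on a live builder):
//
//   XlaBuilder b("axpy");
//   XlaOp a = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "a");
//   XlaOp x = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4}), "x");
//   XlaOp y = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {4}), "y");
//   XlaOp result = a * x + y;  // Equivalent to Add(Mul(a, x), y).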

// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }

  // Clears the HloMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }
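
  // Example (a sketch; OpMetadata's fields come from xla_data.proto):
  //
  //   OpMetadata m;
  //   m.set_op_type("Softmax");
  //   m.set_op_name("model/softmax");
  //   builder.SetOpMetadata(m);
  //   ...  // Instructions enqueued here carry `m`.
  //   builder.ClearOpMetadata();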

  // Sets an OpSharding that will be attached to all instructions until
  // cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Sets the FrontendAttributes that will be added to all instructions until
  // cleared.
  //
  // FrontendAttributes are often applied to a series of XLA HLO instructions.
  // As a result they are set on the computation builder and all the
  // instructions generated via the computation builder will have the same
  // frontend attributes attached to them.
  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
    frontend_attributes_ = frontend_attributes;
  }

  // Swaps the passed FrontendAttributes with the ones currently set.
  //
  // Returns the old attributes.
  FrontendAttributes SwapFrontendAttributes(
      const FrontendAttributes& frontend_attributes) {
    FrontendAttributes old_attributes = std::move(frontend_attributes_);
    frontend_attributes_ = frontend_attributes;
    return old_attributes;
  }

  // Returns the FrontendAttributes that will be attached to all instructions.
  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Clears all the frontend attributes.
  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }

  // Clears the sharding. Ops will be sharded according to the default
  // placement policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build()
  // is called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64 kConvBatchDimension = 0;
  static constexpr int64 kConvFeatureDimension = 1;
  static constexpr int64 kConvFirstSpatialDimension = 2;
  static constexpr int64 kConvSecondSpatialDimension = 3;
  static constexpr int64 kConvKernelOutputDimension = 0;
  static constexpr int64 kConvKernelInputDimension = 1;
  static constexpr int64 kConvKernelFirstSpatialDimension = 2;
  static constexpr int64 kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = false);
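
  // Example (a minimal end-to-end sketch):
  //
  //   XlaBuilder b("add_one");
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   Add(p, ConstantR0<float>(&b, 1.0f));  // The last op added is the root.
  //   StatusOr<XlaComputation> computation = b.Build();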

  // Builds the computation with the requested operations, or notes an error in
  // the parent XlaBuilder and returns an empty computation if building failed.
  // This function is intended to be used where the returned XlaComputation is
  // only used by the parent XlaBuilder and hence further operations on the
  // returned XlaComputation will simply be errored out if an error occurred
  // while building this computation. If the built computation is to be used by
  // an XlaBuilder other than the parent XlaBuilder then Build() should be used
  // instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that is rooted at the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(XlaOp op) const;

  // Returns a pointer to the shape of the given op.
  StatusOr<const Shape*> GetShapePtr(XlaOp op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using
  // the given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //   Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value
  // can be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns
  // an invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
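
  // Example (a sketch of the deferred-error idiom; `MyOpWrapper` and its
  // check are hypothetical, and InvalidArgument() is assumed from xla/util.h):
  //
  //   XlaOp MyOpWrapper(XlaBuilder* b, XlaOp operand) {
  //     return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  //       TF_ASSIGN_OR_RETURN(const Shape* shape, b->GetShapePtr(operand));
  //       if (!shape->IsArray()) {
  //         return InvalidArgument("MyOpWrapper expects an array operand");
  //       }
  //       return Neg(operand);
  //     });
  //   }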

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(XlaOp operand) const;

  // Sets up a binding which indicates that the `target_dim_num` in the
  // subshape `target_param_index` of parameter `target_param_num` is a dynamic
  // dimension and its real dynamic size is represented by
  // `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  Status SetDynamicBinding(int64 dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64 target_param_num,
                           ShapeIndex target_param_index, int64 target_dim_num);

  // Adds a new input/output alias. Since the input/output shape information is
  // not available until the computation is built, any eventual error in the
  // arguments of this API will be detected only at computation Build() time.
  void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                  const ShapeIndex& param_index) {
    input_output_aliases_.push_back({output_index, param_number, param_index});
  }

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64 param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
  };
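
  // Example (a sketch): alias output tuple element {0} to the whole buffer of
  // parameter 0 so the runtime may reuse (donate) that parameter's storage:
  //
  //   builder.SetUpAlias(/*output_index=*/{0},
  //                      /*param_number=*/0,
  //                      /*param_index=*/{});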

  // Looks up the HloInstruction and sets the frontend attribute "attribute" to
  // "value".
  //
  // If the attribute already existed then its value is updated.
  //
  // Note: the attribute is only added to the HloInstruction, not to the
  // builder.
  Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
                                         string value);

 private:
  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);

  // Descriptions for the methods below can be found in the corresponding
  // public functions section in this file.

  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name,
                  const std::vector<bool>& replicated_at_leaf_buffers);
  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name) {
    std::vector<bool> empty_bools;
    return Parameter(parameter_number, shape, name, empty_bools);
  }

  XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);

  XlaOp BroadcastInDim(XlaOp operand,
                       const absl::Span<const int64> out_dim_size,
                       const absl::Span<const int64> broadcast_dimensions);

  XlaOp Pad(XlaOp operand, XlaOp padding_value,
            const PaddingConfig& padding_config);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes,
                int64 inferred_dimension = -1);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                int64 inferred_dimension = -1);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);

  XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                   int64 stride, int64 dimno);

  ABSL_DEPRECATED("Use span-of-indices form instead")
  XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
                     absl::Span<const int64> slice_sizes);
  XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64> slice_sizes);

  ABSL_DEPRECATED("Use span-of-indices form instead")
  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices);
  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                           absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);

  void Trace(const string& tag, XlaOp operand);

  XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(XlaOp tuple_data, int64 index);

  XlaOp Dot(XlaOp lhs, XlaOp rhs,
            const PrecisionConfig* precision_config = nullptr);

  XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                   const DotDimensionNumbers& dimension_numbers,
                   const PrecisionConfig* precision_config = nullptr);

  XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
             Padding padding, int64 feature_group_count = 1,
             int64 batch_group_count = 1,
             const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64> window_strides,
                    absl::Span<const std::pair<int64, int64>> padding,
                    const ConvolutionDimensionNumbers& dimension_numbers,
                    int64 feature_group_count = 1, int64 batch_group_count = 1,
                    const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           absl::Span<const int64> lhs_dilation,
                           absl::Span<const int64> rhs_dilation,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count = 1,
                           int64 batch_group_count = 1,
                           const PrecisionConfig* precision_config = nullptr);

  XlaOp Fft(XlaOp operand, FftType fft_type,
            absl::Span<const int64> fft_length);

  XlaOp Infeed(const Shape& shape, const string& config = "");
  XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                        const string& config = "");

  void Outfeed(XlaOp operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);

  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);

  XlaOp Reduce(XlaOp operand, XlaOp init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);

  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp AllReduce(
      XlaOp operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Shape>& shape_with_layout = absl::nullopt);

  XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                 int64 split_count,
                 const std::vector<ReplicaGroup>& replica_groups);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, XlaOp source, XlaOp init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  XlaOp Iota(const Shape& shape, int64 iota_dimension);

  XlaOp Iota(PrimitiveType type, int64 size);

  XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);

  XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
             int64 dimension = -1, bool is_stable = false);

  XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

  XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              XlaOp init);

  XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                    const XlaComputation& true_computation, XlaOp false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(XlaOp branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                        const int mantissa_bits);

  XlaOp Gather(XlaOp input, XlaOp start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes,
               bool indices_are_sorted = false);

  XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);

  void Send(XlaOp operand, const ChannelHandle& handle);
  XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

  XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                   const ChannelHandle& handle);

  XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                     const ChannelHandle& handle);

  XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                          float epsilon, int64 feature_index);

  XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                           XlaOp variance, float epsilon, int64 feature_index);

  XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                      XlaOp batch_var, XlaOp grad_output, float epsilon,
                      int64 feature_index);

  XlaOp GetDimensionSize(XlaOp operand, int64 dimension);

  XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                                 absl::Span<const XlaOp> operands = {});

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64 handle) const;
  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);

  // Internal helper method that does the building for an arbitrary unary op.
  XlaOp UnaryOp(HloOpcode unop, XlaOp operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction is
  // only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                 absl::Span<const int64> broadcast_dimensions,
                 absl::optional<ComparisonDirection> direction = absl::nullopt);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  StatusOr<XlaOp> InDimBroadcast(const Shape& shape, XlaOp operand,
                                 absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       XlaOp operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  StatusOr<XlaOp> Reshape(const Shape& shape, XlaOp operand,
                          int64 inferred_dimension = -1);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false
  // iff a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64 op_handle,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  int64 GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored within
  // the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
  std::vector<std::unique_ptr<Shape>> instruction_shapes_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias()
  // API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where
  // the instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this
  // metadata throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Sharding for this operator. This is structured as a "modal"-like
  // operation, in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  FrontendAttributes frontend_attributes_;

  friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                         const Shape& shape, const string& name,
                         const std::vector<bool>& replicated_at_leaf_buffers);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(XlaOp operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      XlaOp operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Copy(XlaOp operand);

  friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64> new_sizes,
                                            int64 inferred_dimension);

  friend XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  friend XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                          int64 stride, int64 dimno);

  friend XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
                            absl::Span<const int64> slice_sizes);
  friend XlaOp DynamicSlice(XlaOp operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  XlaOp start_indices);
  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64 dimension);

  friend void Trace(const string& tag, XlaOp operand);

  friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
  friend XlaOp Eq(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ne(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ge(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Gt(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Lt(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Le(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
                   const PrecisionConfig* precision_config);
  friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config);
  friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64 feature_group_count, int64 batch_group_count,
                    const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count, int64 batch_group_count,
                           const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp Fft(XlaOp operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                          absl::Span<const XlaOp> operands, const Shape& shape,
                          const string& opaque);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
  friend XlaOp Complex(XlaOp real, XlaOp imag,
                       absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Conj(XlaOp operand);
  friend XlaOp Add(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Div(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Max(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Min(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp And(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Or(XlaOp lhs, XlaOp rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Not(XlaOp operand);
  friend XlaOp PopulationCount(XlaOp operand);
  friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                         absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  friend XlaOp CrossReplicaSum(XlaOp operand,
                               absl::Span<const ReplicaGroup> replica_groups);
  friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                         absl::Span<const ReplicaGroup> replica_groups,
                         const absl::optional<ChannelHandle>& channel_id,
                         const absl::optional<Shape>& shape_with_layout);
  friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
                        int64 concat_dimension, int64 split_count,
                        const std::vector<ReplicaGroup>& replica_groups);
  friend XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
  friend XlaOp ReplicaId(XlaBuilder* builder);
  friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                absl::Span<const int64> window_dimensions,
                                absl::Span<const int64> window_strides,
                                Padding padding, XlaOp source, XlaOp init_value,
                                const XlaComputation& scatter);
  friend XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);
  friend XlaOp Abs(XlaOp operand);
  friend XlaOp Atan2(XlaOp y, XlaOp x,
                     absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Exp(XlaOp operand);
  friend XlaOp Expm1(XlaOp operand);
  friend XlaOp Floor(XlaOp operand);
  friend XlaOp Ceil(XlaOp operand);
  friend XlaOp Round(XlaOp operand);
  friend XlaOp Log(XlaOp operand);
  friend XlaOp Log1p(XlaOp operand);
  friend XlaOp Sign(XlaOp operand);
  friend XlaOp Clz(XlaOp operand);
  friend XlaOp Cos(XlaOp operand);
  friend XlaOp Sin(XlaOp operand);
  friend XlaOp Tanh(XlaOp operand);
  friend XlaOp Real(XlaOp operand);
  friend XlaOp Imag(XlaOp operand);
  friend XlaOp Sqrt(XlaOp operand);
  friend XlaOp Rsqrt(XlaOp operand);
  friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp IsFinite(XlaOp operand);
  friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
                    int64 iota_dimension);
  friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
  friend XlaOp ConvertElementType(XlaOp operand,
                                  PrimitiveType new_element_type);
  friend XlaOp BitcastConvertType(XlaOp operand,
                                  PrimitiveType new_element_type);
  friend XlaOp Neg(XlaOp operand);
  friend XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
  friend XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
  friend XlaOp Sort(absl::Span<const XlaOp> operands,
                    const XlaComputation& comparator, int64 dimension,
                    bool is_stable);
  friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
  friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                   const XlaComputation& computation,
                   absl::Span<const int64> dimensions,
                   absl::Span<const XlaOp> static_operands);
  friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
  friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
  friend XlaOp While(const XlaComputation& condition,
                     const XlaComputation& body, XlaOp init);
  friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                           const XlaComputation& true_computation,
                           XlaOp false_operand,
                           const XlaComputation& false_computation);
  friend XlaOp Conditional(
      XlaOp branch_index,
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);
  friend XlaOp ConditionalImpl(
      XlaOp branch_index,
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);
  friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                               const int mantissa_bits);
  friend XlaOp Gather(XlaOp input, XlaOp start_indices,
                      const GatherDimensionNumbers& dimension_numbers,
                      absl::Span<const int64> slice_sizes,
                      bool indices_are_sorted);
  friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                       const XlaComputation& update_computation,
                       const ScatterDimensionNumbers& dimension_numbers,
                       bool indices_are_sorted, bool unique_indices);
  friend void Send(XlaOp operand, const ChannelHandle& handle);
  friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
                    const ChannelHandle& handle);
  friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                                 float epsilon, int64 feature_index);
  friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                                  XlaOp mean, XlaOp variance, float epsilon,
                                  int64 feature_index);
  friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                             XlaOp batch_var, XlaOp grad_output, float epsilon,
                             int64 feature_index);
  friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
                             const ChannelHandle& handle);
  friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                             const ChannelHandle& handle);
  friend XlaOp SendToHost(XlaOp operand, XlaOp token,
                          const Shape& shape_with_layout,
                          const ChannelHandle& handle);
  friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                            const ChannelHandle& handle);
  friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                               const string& config);
  friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                                const Shape& shape_with_layout,
                                const string& outfeed_config);
  friend XlaOp CreateToken(XlaBuilder* builder);
  friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

  friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
  friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

 private:
  XlaOp ConditionalImpl(
      XlaOp branch_index,
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);
};

// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
class XlaScopedShardingAssignment {
 public:
  XlaScopedShardingAssignment(xla::XlaBuilder* builder,
                              absl::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
  XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
      delete;

  ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  void SetSharding(const absl::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::XlaBuilder* const builder_;
  absl::optional<OpSharding> prev_sharding_;
};
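
// Example (a sketch; OpSharding comes from xla_data.proto):
//
//   OpSharding replicated;
//   replicated.set_type(OpSharding::REPLICATED);
//   {
//     XlaScopedShardingAssignment scoped(&builder, replicated);
//     ...  // Ops created here are annotated with `replicated`.
//   }  // The previous sharding (if any) is restored here.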

// RAII-style object: saves the current builder's frontend attributes and
// replaces them with the passed-in ones on construction.
// Restores the original attributes on destruction.
class XlaScopedFrontendAttributesAssignment {
 public:
  XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
                                        FrontendAttributes attributes)
      : builder_(builder) {
    saved_ = builder_->SwapFrontendAttributes(attributes);
  }

  ~XlaScopedFrontendAttributesAssignment() {
    builder_->SetFrontendAttributes(saved_);
  }

 private:
  xla::XlaBuilder* const builder_;
  FrontendAttributes saved_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
};
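
// Example (a sketch; the attribute key/value below are hypothetical, and
// FrontendAttributes is a string-to-string map proto from xla_data.proto):
//
//   FrontendAttributes attrs;
//   (*attrs.mutable_map())["keep_inlineable"] = "false";
//   {
//     XlaScopedFrontendAttributesAssignment scoped(&builder, attrs);
//     ...  // Ops created here carry the attribute above.
//   }  // The original attributes are restored here.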

// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.
//

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name);

// Same as above, but with leaf buffer replication annotation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name,
                const std::vector<bool>& replicated_at_leaf_buffers);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);

// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//   Native Type   PrimitiveType
//   -----------------------------
//   bool          PRED
//   int32         S32
//   int64         S64
//   uint32        U32
//   uint64        U64
//   float         F32
//   double        F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
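
// Example (a sketch of a few of the constant builders above, on builder `b`):
//
//   XlaOp zero = ConstantR0<float>(&b, 0.0f);                // f32[] scalar.
//   XlaOp ids = ConstantR1<int32>(&b, {1, 2, 3});            // s32[3].
//   XlaOp ones = ConstantR1<float>(&b, /*length=*/8, 1.0f);  // f32[8].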

// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);
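
// For example, broadcasting an f32[2] operand with broadcast_sizes={3} yields
// an f32[3, 2] result in which output[i, j] == operand[j].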

// This op broadcasts the `operand` to an output with the given `shape`.
// `broadcast_dimensions` are the dimensions to broadcast into, i.e., the
// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
// dimension of the output. This also requires that the i'th input dimension is
// either 1 or is the same as the output dimension it's broadcasting into.
//
// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
// output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimension will generate output
//   {{1, 2},
//    {1, 2}}
// - On the other hand, specifying {0} as broadcast_dimension
//   will generate output
//   {{1 , 1},
//    {2 , 2}}
XlaOp BroadcastInDim(XlaOp operand, const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions);

// Copies the input operand to the output. This operation is for internal
// purposes only and is only used by the compiler for optimization purposes or
// to ensure correctness. The XLA client should never have to generate this
// instruction.
//
// Copy has two potential use cases:
//
// * Create a copy of the operand with a new layout.
//
// * Create a copy of the operand in a separately allocated buffer. This is
//   necessary for some backends if the operand is a parameter or constant and
//   the operand is returned within a tuple. In this case, the lifetime of the
//   operand buffer must be the same as the lifetime of the output result.
//   However, the lifetimes of parameters and constants are managed separately
//   from the lifetime of the output result. Creating a separate copy of the
//   parameter or constant buffer resolves this issue.
XlaOp Copy(XlaOp operand);

// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(XlaOp operand, XlaOp padding_value,
          const PaddingConfig& padding_config);
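
// Example (a sketch; PaddingConfig comes from xla_data.proto, and `v` is
// assumed to be a rank-1 f32 operand on builder `b`):
//
//   PaddingConfig config;
//   auto* dim = config.add_dimensions();
//   dim->set_edge_padding_low(1);   // One padding element on the left.
//   dim->set_edge_padding_high(2);  // Two padding elements on the right.
//   dim->set_interior_padding(0);   // No padding between elements.
//   XlaOp padded = Pad(v, ConstantR0<float>(&b, 0.0f), config);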

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
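
// Example (a sketch): reshaping an f32[2, 3] constant into f32[3, 2] keeps
// the major-to-minor element order:
//
//   XlaOp m = ConstantR2<float>(&b, {{1, 2, 3}, {4, 5, 6}});
//   XlaOp r = Reshape(m, /*new_sizes=*/{3, 2});  // {{1, 2}, {3, 4}, {5, 6}}.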
1243
1244 // `inferred_dimension` represents the output dimension that's inferred by
1245 // upper-level framework by dividing the input element count by the known
1246 // output element count. While an inferred_dimension can be static, if there
1247 // is a dynamic dimension in the output, it must be the inferred dimension.
1248 XlaOp ReshapeWithInferredDimension(XlaOp operand,
1249 absl::Span<const int64> new_sizes,
1250 int64 inferred_dimension);
1251
1252 // Wrapper for Reshape.
1253 // Enqueues an operation to collapse the provided dimensions; e.g. an
1254 // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
1255 // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
1256 // be a consecutive, in-order subsequence of the operand dimensions.
1257 //
1258 // Note that collapsing a single dimension does nothing:
1259 //
1260 // {256} collapsing {0} => {256}
1261 // {1} collapsing {0} => {1}
1262 //
1263 // Collapsing multiple dimensions produces a single result dimension:
1264 //
1265 // {256, 2} collapsing {0,1} => {512}
1266 // {256, 2, 3} collapsing {0,1} => {512, 3}
1267 //
1268 // This could potentially cause data to be moved -- it provides a more
1269 // structured form of reshaping than an arbitrary Reshape operation.
1270 XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
1271
1272 // Enqueues a slice operation onto the computation that slices the operand
1273 // from the start indices to the limit indices; e.g.
1274 //
1275 // x
1276 // [ 0 1 2 3 ]
1277 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
1278 // [ 8 9 a b ]
1279 //
1280 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
1281 // range notation.
1282 // The strides parameter determines the stride over the slice
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides);

// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
//  array[:, 2:4:1, :]
XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno);

// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes'; each
// size specifies an exclusive slice interval in that dimension:
// [start, start + size).
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices'
// must have the same shape.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
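//
// For example (an illustrative sketch, assuming the s32[3,4] array `x` and
// builder `b` from the Slice example above):
//
//   XlaOp i = ConstantR0<int32>(&b, 1);
//   XlaOp j = ConstantR0<int32>(&b, 1);
//   XlaOp ds = DynamicSlice(x, /*start_indices=*/{i, j},
//                           /*slice_sizes=*/{1, 2});  // s32[1,2]: {{5, 6}}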
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes);

ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
                   absl::Span<const int64> slice_sizes);

// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
//             update = {10, 11} // calculated at runtime.
//   [1 2 3]   start  = {1, 1}   // calculated at runtime.   [1 2  3 ]
//   [4 5 6] => DynamicUpdateSlice(data, update, start)   => [4 10 11]
//   [7 8 9]                                                 [7 8  9 ]
//
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices'
// must have the same shape.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
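//
// For example (an illustrative sketch mirroring the diagram above, assuming
// an XlaBuilder `b`):
//
//   XlaOp data = ConstantR2<int32>(&b, {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
//   XlaOp update = ConstantR2<int32>(&b, {{10, 11}});  // s32[1,2]
//   XlaOp start = ConstantR0<int32>(&b, 1);
//   XlaOp result = DynamicUpdateSlice(data, update,
//                                     /*start_indices=*/{start, start});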
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                         absl::Span<const XlaOp> start_indices);

ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices);

// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
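//
// For example (an illustrative sketch, assuming an XlaBuilder `b`):
//
//   XlaOp a = ConstantR2<float>(&b, {{1, 2}});             // f32[1,2]
//   XlaOp c = ConstantR2<float>(&b, {{3, 4}});             // f32[1,2]
//   XlaOp cat = ConcatInDim(&b, {a, c}, /*dimension=*/0);  // f32[2,2]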
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64 dimension);

// Enqueue a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
void Trace(const string& tag, XlaOp operand);

// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);

// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(XlaOp tuple_data, int64 index);

// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a comparison instruction onto the computation.
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction);

// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,
          const PrecisionConfig* precision_config = nullptr);

// Enqueues a general dot instruction onto the computation.
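//
// For example (an illustrative sketch, assuming 2-D operands `lhs` and `rhs`;
// with these dimension numbers DotGeneral computes a plain matrix product,
// like Dot):
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);  // contract lhs columns ...
//   dnums.add_rhs_contracting_dimensions(0);  // ... against rhs rows
//   XlaOp prod = DotGeneral(lhs, rhs, dnums);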
XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
           Padding padding, int64 feature_group_count = 1,
           int64 batch_group_count = 1,
           const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs,
                             absl::Span<const int64> window_strides,
                             absl::Span<const std::pair<int64, int64>> padding,
                             int64 feature_group_count = 1,
                             int64 batch_group_count = 1,
                             const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64 feature_group_count = 1, int64 batch_group_count = 1,
                  const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64 feature_group_count = 1,
                         int64 batch_group_count = 1,
                         const PrecisionConfig* precision_config = nullptr);

// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);

// Solves systems of linear equations with lower or upper triangular
// coefficient matrices by forward- or back-substitution. Broadcasting along
// leading dimensions, this routine solves for x in one of the matrix systems
//   `op(a) * x = b`, or `x * op(a) = b`,
// for the variable `x` given `a` and `b`, where `op(a)` is either
//   `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
//
// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
//   square matrices. If `lower` is true (false), then the strictly upper
//   (lower) triangular part of each innermost matrix in `a` is assumed to be
//   zero and is not accessed.
// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
//   tensor of shape `[..., K, M]`.
// * `left_side` is a boolean, indicating whether to solve a system of the form
//   op(a) * x = b (true) or x * op(a) = b (false).
// * `lower` is a boolean, indicating whether the argument `a` is
//   lower-triangular (true) or upper-triangular (false).
// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
//   1 and not accessed.
// * `transpose_a` indicates which function `op` we use to transform the tensor
//   `a`: the identity function, transpose(a), or conjugate(transpose(a)).
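//
// For example (an illustrative sketch solving a * x = b for a 2x2
// lower-triangular `a`):
//
//   XlaBuilder builder("triangular_solve_example");
//   XlaOp a = ConstantR2<float>(&builder, {{2, 0}, {1, 2}});  // f32[2,2]
//   XlaOp rhs = ConstantR2<float>(&builder, {{2}, {3}});      // f32[2,1]
//   XlaOp x = TriangularSolve(a, rhs, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);
//   // x is f32[2,1]: {{1}, {1}}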
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a);

// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
// positive definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
// two minor dimensions equal.
// If `lower` is true, the data from the lower triangle is used; if false, the
// upper triangle is used. The input data in the other triangle of the input
// does not affect the output. Returns the output in the same lower/upper
// triangle. The data returned in the other output triangle is arbitrary and
// implementation-defined.
//
// If `a` is not Hermitian positive definite, returns an array full of NaNs.
XlaOp Cholesky(XlaOp a, bool lower);

// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const string& config = "");

// Variant of Infeed which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                      const string& config = "");

// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(XlaOp operand, const Shape& shape_with_layout,
             const string& outfeed_config);

// Variant of Outfeed which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config);

// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands);

// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external code,
// and the external code is expected to produce a result of the given
// |shape|. The exact mechanism is backend-specific. For example, in the CPU
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
// but |call_target_name| should be short as it may be used in labels. |opaque|
// can encode arbitrarily large amounts of information.
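//
// For example (an illustrative sketch, assuming a builder `builder`, an
// operand `x`, and a hypothetical symbol "my_custom_fn" that would have to be
// registered with the backend):
//
//   XlaOp out = CustomCall(&builder, "my_custom_fn", /*operands=*/{x},
//                          /*shape=*/ShapeUtil::MakeShape(F32, {2}));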
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                 absl::Span<const XlaOp> operands, const Shape& shape,
                 const string& opaque = "");

// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
// external code, and the external code is expected to produce a result with the
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
// and |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
                           absl::Span<const XlaOp> operands,
                           const Shape& shape_with_layout,
                           absl::Span<const Shape> operand_shapes_with_layout,
                           const string& opaque = "");

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).
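//
// For example (an illustrative sketch of broadcast_dimensions, assuming an
// XlaBuilder `b`; values and shapes are made up):
//
//   XlaOp m = ConstantR2<float>(&b, {{1, 2}, {3, 4}});  // f32[2,2]
//   XlaOp v = ConstantR1<float>(&b, {10, 20});          // f32[2]
//   // broadcast_dimensions={1} maps v's dimension 0 onto m's dimension 1,
//   // i.e. v is added to every row of m:
//   XlaOp sum = Add(m, v, /*broadcast_dimensions=*/{1});
//   // sum is f32[2,2]: {{11, 22}, {13, 24}}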

// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(XlaOp real, XlaOp imag,
              absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(XlaOp operand);

// Enqueues an add instruction onto the computation.
XlaOp Add(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a subtract instruction onto the computation.
XlaOp Sub(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a multiply instruction onto the computation.
XlaOp Mul(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a divide instruction onto the computation.
XlaOp Div(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a remainder instruction onto the computation.
XlaOp Rem(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a max instruction onto the computation.
XlaOp Max(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a min instruction onto the computation.
XlaOp Min(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Element-wise logical operators
XlaOp And(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Overload to call And with 3 or more operands. We need the following somewhat
// convoluted overload set to disambiguate with the overload that takes the
// `broadcast_dimensions` optional param.
inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
  return And(op1, And(op2, op3));
}
template <typename... XlaOpTs>
XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return And(op1, And(op2, And(op3, operands...)));
}

XlaOp Or(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Overload to call Or with 3 or more operands. As with `And`, we need the
// following complicated overload set to handle the default arg in the `Or`
// overload above.
inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
  return Or(op1, Or(op2, op3));
}
template <typename... XlaOpTs>
XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return Or(op1, Or(op2, Or(op3, operands...)));
}

XlaOp Xor(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Not(XlaOp operand);

XlaOp PopulationCount(XlaOp operand);

XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
                        absl::Span<const int64> broadcast_dimensions = {});

// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
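//
// For example (an illustrative sketch summing the rows of a 2x3 array,
// assuming an XlaBuilder `b`):
//
//   // Build a scalar addition computation to use as the reduction operator.
//   XlaBuilder sub("add");
//   Add(Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "x"),
//       Parameter(&sub, 1, ShapeUtil::MakeShape(F32, {}), "y"));
//   XlaComputation add = sub.Build().ValueOrDie();
//
//   XlaOp m = ConstantR2<float>(&b, {{1, 2, 3}, {4, 5, 6}});  // f32[2,3]
//   XlaOp row_sums = Reduce(m, ConstantR0<float>(&b, 0.0f), add,
//                           /*dimensions_to_reduce=*/{1});  // f32[2]: {6, 15}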
XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                const XlaComputation& computation);

// Enqueues a windowed reduce instruction onto the computation.
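//
// For example (an illustrative sketch of 1D max-pooling, assuming an
// XlaBuilder `b` and a scalar Max computation `max_comp` built like `add` in
// the Reduce example above, but with Max(x, y)):
//
//   XlaOp v = ConstantR1<float>(&b, {1, 3, 2, 5});  // f32[4]
//   XlaOp lowest =
//       ConstantR0<float>(&b, -std::numeric_limits<float>::infinity());
//   XlaOp pooled = ReduceWindow(v, lowest, max_comp,
//                               /*window_dimensions=*/{2},
//                               /*window_strides=*/{2},
//                               Padding::kValid);  // f32[2]: {3, 5}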
XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    XlaOp operand, XlaOp init_value, const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});

// Enqueues an operation that performs an AllReduce of the operand across
// cores. Here AllReduce means doing a reduction on the input operand across
// cores and then broadcasting the reduction result to those cores. The
// reduction function is defined by `computation`, which should be a
// commutative computation on scalars, e.g., add, min, or max. The way that
// AllReduce is applied is configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
// empty, all replicas belong to one group. AllReduce will be applied within
// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
// replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied across modules.
//
// - `shape_with_layout`: forces the layout of the AllReduce to the given
// layout. This is used to guarantee the same layout for a group of AllReduce
// ops compiled separately.
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups = {},
                const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
                const absl::optional<Shape>& shape_with_layout = absl::nullopt);

// Enqueues an operation that performs an AllToAll of the operand across cores.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
               int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups = {});

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that 1) no two pairs may have the same target replica
// id, and no two pairs may have the same source replica id; 2) if a replica id
// is not a target in any pair, then the output on that replica is a tensor
// consisting of 0(s) with the same shape as the input.
XlaOp CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs);

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       XlaOp source, XlaOp init_value,
                       const XlaComputation& scatter);

// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(XlaOp operand);

// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(XlaOp y, XlaOp x,
            absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(XlaOp operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(XlaOp operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(XlaOp operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(XlaOp operand);

// Enqueues a round instruction onto the computation, rounding to the nearest
// integer, with half-way cases rounding away from zero.
XlaOp Round(XlaOp operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(XlaOp operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(XlaOp operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(XlaOp operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(XlaOp operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(XlaOp operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(XlaOp operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(XlaOp operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(XlaOp operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(XlaOp operand);

// Enqueues a sqrt computation onto the computation.
XlaOp Sqrt(XlaOp operand);

// Enqueues a rsqrt computation onto the computation.
XlaOp Rsqrt(XlaOp operand);

// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e., not
// +/-Inf or NaN. Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(XlaOp operand);

// Enqueues an iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a negate instruction onto the computation.
XlaOp Neg(XlaOp operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether the stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided, then the
//   last dimension is chosen by default. For the dimension which is sorted,
//   the same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types
//   of the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i
//   at two index positions.
// Default comparator computations can be found in lib/comparators.h
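//
// For example (an illustrative sketch of an ascending f32 sort, assuming an
// XlaBuilder `b`):
//
//   // Build a scalar less-than comparator.
//   XlaBuilder cmp("lt");
//   Lt(Parameter(&cmp, 0, ShapeUtil::MakeShape(F32, {}), "a"),
//      Parameter(&cmp, 1, ShapeUtil::MakeShape(F32, {}), "b"));
//   XlaComputation less = cmp.Build().ValueOrDie();
//
//   XlaOp v = ConstantR1<float>(&b, {3, 1, 2});
//   XlaOp sorted = Sort({v}, less);  // f32[3]: {1, 2, 3}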
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension = -1, bool is_stable = false);

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

// Enqueues a while node onto the computation.
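//
// For example (an illustrative sketch of `while (x < 10) x = x + 1;`,
// assuming an XlaBuilder `b`):
//
//   const Shape s32 = ShapeUtil::MakeShape(S32, {});
//
//   XlaBuilder condb("cond");
//   Lt(Parameter(&condb, 0, s32, "x"), ConstantR0<int32>(&condb, 10));
//   XlaComputation cond = condb.Build().ValueOrDie();
//
//   XlaBuilder bodyb("body");
//   Add(Parameter(&bodyb, 0, s32, "x"), ConstantR0<int32>(&bodyb, 1));
//   XlaComputation body = bodyb.Build().ValueOrDie();
//
//   XlaOp result = While(cond, body, ConstantR0<int32>(&b, 0));  // => 10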
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            XlaOp init);

// Enqueues a conditional node onto the computation.
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                  const XlaComputation& true_computation, XlaOp false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index uses the (N-1)'th
// branch_computation as the default.
XlaOp Conditional(XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
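//
// For example (an illustrative sketch gathering rows 2 and 0 of a
// hypothetical s32[3,4] `matrix`, producing an s32[2,4] result; assumes an
// XlaBuilder `b`):
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 holds the row contents
//   dnums.add_collapsed_slice_dims(0);  // the size-1 row dimension collapses
//   dnums.add_start_index_map(0);       // indices select along operand dim 0
//   dnums.set_index_vector_dim(1);
//   XlaOp indices = ConstantR2<int32>(&b, {{2}, {0}});  // s32[2,1]
//   XlaOp rows = Gather(matrix, indices, dnums, /*slice_sizes=*/{1, 4});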
XlaOp Gather(XlaOp input, XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes,
             bool indices_are_sorted = false);

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64 feature_index);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
// - grad_operand: Gradient with respect to input `operand`
// - grad_offset: Gradient with respect to input `offset`
// - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64 feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64 dimension);

XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

// Implementation details below this point.
//

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  BorrowingLiteral literal(
      reinterpret_cast<const char*>(values.begin()),
      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
                           {static_cast<int64>(values.size())}));
  return ConstantLiteral(builder, literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_