/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;

// This represents an instruction that has been enqueued using the XlaBuilder.
// This is used to pass to subsequent computations that depend upon the
// instruction as an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar(). Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point possibly
  // deep in the callstack when we finally dereference `this`. The precondition
  // lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(const XlaOp& rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64 handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};
106
107 // Arithmetic operator overloads for the XlaOp type.
108 XlaOp operator-(const XlaOp& x);
109 XlaOp operator+(const XlaOp& x, const XlaOp& y);
110 XlaOp operator-(const XlaOp& x, const XlaOp& y);
111 XlaOp operator*(const XlaOp& x, const XlaOp& y);
112 XlaOp operator/(const XlaOp& x, const XlaOp& y);
113 XlaOp operator%(const XlaOp& x, const XlaOp& y);
114
115 // Bitwise operator overloads for the XlaOp type.
116 XlaOp operator~(const XlaOp& x);
117 XlaOp operator&(const XlaOp& x, const XlaOp& y);
118 XlaOp operator|(const XlaOp& x, const XlaOp& y);
119 XlaOp operator^(const XlaOp& x, const XlaOp& y);
120 XlaOp operator<<(const XlaOp& x, const XlaOp& y);
121 // Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
122 // a right logical shift.
123 XlaOp operator>>(const XlaOp& x, const XlaOp& y);
124
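// For example (an illustrative sketch), given XlaOps `a`, `b`, and `c`
// produced by the same builder, the overloads above compose as ordinary C++
// expressions:
//
//   XlaOp d = (a + b) * c;  // Enqueues an Add followed by a Mul.
//   XlaOp s = a >> b;       // Arithmetic or logical shift, based on a's type.
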
// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
// semantics might be surprising since their result types are usually 'bool'.
// Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.

// A convenient interface for building up computations.
//
// Thread-compatible.
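//
// A minimal usage sketch (illustrative; the computation name and shapes are
// hypothetical):
//
//   XlaBuilder b("add_two_floats");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
//   XlaOp sum = x + y;  // The last enqueued op becomes the root.
//   StatusOr<XlaComputation> computation = b.Build();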
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the Computation Builder. All subsequent
  // instructions generated via this Computation Builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }

  // Clears the HloMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }

  // Sets an OpSharding that will be attached to all instructions until cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Clears the sharding. Ops will be sharded according to the default placement
  // policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build() is
  // called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64 kConvBatchDimension = 0;
  static constexpr int64 kConvFeatureDimension = 1;
  static constexpr int64 kConvFirstSpatialDimension = 2;
  static constexpr int64 kConvSecondSpatialDimension = 3;
  static constexpr int64 kConvKernelOutputDimension = 0;
  static constexpr int64 kConvKernelInputDimension = 1;
  static constexpr int64 kConvKernelFirstSpatialDimension = 2;
  static constexpr int64 kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = true);

  // Builds the computation with the requested operations, or notes an error in
  // the parent XlaBuilder and returns an empty computation if building failed.
  // This function is intended to be used where the returned XlaComputation is
  // only used by the parent XlaBuilder and hence further operations on the
  // returned XlaComputation will simply error out if an error occurred
  // while building this computation. If the built computation is to be used by
  // an XlaBuilder other than the parent XlaBuilder then Build() should be used
  // instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that roots on the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(const XlaOp& op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using the
  // given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //    Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value can
  // be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns an
  // invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(const XlaOp& operand) const;

  // Sets up a binding which indicates that the `target_dim_num` in the subshape
  // `target_param_index` of parameter `target_param_num` is a dynamic dimension
  // and its real dynamic size is represented by `dynamic_param_index` in
  // parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  Status SetDynamicBinding(int64 dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64 target_param_num,
                           ShapeIndex target_param_index, int64 target_dim_num);
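
  // For example (an illustrative sketch; the parameter numbering is
  // hypothetical): to declare that dimension 0 of parameter 1 is dynamic, with
  // its runtime size carried by the scalar parameter 0:
  //
  //   TF_RETURN_IF_ERROR(builder.SetDynamicBinding(
  //       /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{},
  //       /*target_param_num=*/1, /*target_param_index=*/{},
  //       /*target_dim_num=*/0));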

  // Adds a new input/output alias. Since the input/output shape information is
  // not available until the computation is built, any eventual error in the
  // arguments of this API will be detected only at computation Build() time.
  void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                  const ShapeIndex& param_index) {
    input_output_aliases_.push_back({output_index, param_number, param_index});
  }
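
  // For example (illustrative), to alias the output tuple element at index {0}
  // with the whole of parameter 1:
  //
  //   builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
  //                      /*param_index=*/{});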

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64 param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
  };

 private:
  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);

  // Descriptions for the methods below can be found in the corresponding public
  // functions section in this file.

  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name);

  XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(const XlaOp& operand,
                  absl::Span<const int64> broadcast_sizes);

  XlaOp BroadcastInDim(const XlaOp& operand,
                       const absl::Span<const int64> out_dim_size,
                       const absl::Span<const int64> broadcast_dimensions);

  XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
            const PaddingConfig& padding_config);

  XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes);

  XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

  XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);

  XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);

  XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                   int64 stride, int64 dimno);

  ABSL_DEPRECATED("Use span-of-indices form instead")
  XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                     absl::Span<const int64> slice_sizes);
  XlaOp DynamicSlice(const XlaOp& operand,
                     absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64> slice_sizes);

  ABSL_DEPRECATED("Use span-of-indices form instead")
  XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                           const XlaOp& start_indices);
  XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                           absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);

  void Trace(const string& tag, const XlaOp& operand);

  XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);

  XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
            const PrecisionConfig* precision_config = nullptr);

  XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                   const DotDimensionNumbers& dimension_numbers,
                   const PrecisionConfig* precision_config = nullptr);

  XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
             absl::Span<const int64> window_strides, Padding padding,
             int64 feature_group_count = 1, int64 batch_group_count = 1,
             const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvWithGeneralPadding(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvWithGeneralDimensions(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> window_strides,
                    absl::Span<const std::pair<int64, int64>> padding,
                    const ConvolutionDimensionNumbers& dimension_numbers,
                    int64 feature_group_count = 1, int64 batch_group_count = 1,
                    const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           absl::Span<const int64> lhs_dilation,
                           absl::Span<const int64> rhs_dilation,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count = 1,
                           int64 batch_group_count = 1,
                           const PrecisionConfig* precision_config = nullptr);

  XlaOp Fft(const XlaOp& operand, FftType fft_type,
            absl::Span<const int64> fft_length);

  XlaOp Infeed(const Shape& shape, const string& config = "");
  XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                        const string& config = "");

  void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);

  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);

  XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      const XlaOp& operand, const XlaOp& init_value,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);

  XlaOp CrossReplicaSum(const XlaOp& operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp CrossReplicaSum(
      const XlaOp& operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt);

  XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
                 int64 concat_dimension, int64 split_count,
                 const std::vector<ReplicaGroup>& replica_groups);

  XlaOp CollectivePermute(
      const XlaOp& operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, const XlaOp& source,
                         const XlaOp& init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      const XlaOp& operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
      const XlaOp& init_value, const XlaComputation& scatter);

  XlaOp Iota(const Shape& shape, int64 iota_dimension);

  XlaOp Iota(PrimitiveType type, int64 size);

  XlaOp ConvertElementType(const XlaOp& operand,
                           PrimitiveType new_element_type);

  XlaOp BitcastConvertType(const XlaOp& operand,
                           PrimitiveType new_element_type);

  XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);

  XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);

  ABSL_DEPRECATED("Use form with comparator computation instead")
  XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
             int64 dimension = -1);
  XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
             int64 dimension = -1, bool is_stable = false);

  XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);

  XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              const XlaOp& init);

  XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                    const XlaComputation& true_computation,
                    const XlaOp& false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(const XlaOp& branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                        const int mantissa_bits);

  XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes);

  XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                const XlaOp& updates, const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers);

  void Send(const XlaOp& operand, const ChannelHandle& handle);
  XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                      const ChannelHandle& handle);

  XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                   const Shape& shape_with_layout, const ChannelHandle& handle);

  XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                     const ChannelHandle& handle);

  XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                          const XlaOp& offset, float epsilon,
                          int64 feature_index);

  XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                           const XlaOp& offset, const XlaOp& mean,
                           const XlaOp& variance, float epsilon,
                           int64 feature_index);

  XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                      const XlaOp& batch_mean, const XlaOp& batch_var,
                      const XlaOp& grad_output, float epsilon,
                      int64 feature_index);

  XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);

  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                                 absl::Span<const XlaOp> operands = {});

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64 handle) const;

  // Internal helper method that does the building for an arbitrary unary op.
  XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction is
  // only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
                 absl::Span<const int64> broadcast_dimensions,
                 absl::optional<ComparisonDirection> direction = absl::nullopt);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
                  const XlaOp& ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  StatusOr<XlaOp> InDimBroadcast(const Shape& shape, const XlaOp& operand,
                                 absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       const XlaOp& operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false iff
  // a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64 op_handle,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  // Helper function for creating a Window proto from user-supplied data.
  // Returns an error if the user-supplied data was invalid.
  StatusOr<Window> MakeWindow(absl::Span<const int64> window_dimensions,
                              absl::Span<const int64> window_strides,
                              absl::Span<const std::pair<int64, int64>> padding,
                              absl::Span<const int64> lhs_dilation,
                              absl::Span<const int64> rhs_dilation) const;

  int64 GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored within
  // the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias() API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where the
  // instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                         const Shape& shape, const string& name);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(const XlaOp& operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      const XlaOp& operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

  friend XlaOp Collapse(const XlaOp& operand,
                        absl::Span<const int64> dimensions);

  friend XlaOp Slice(const XlaOp& operand,
                     absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
                          int64 limit_index, int64 stride, int64 dimno);

  friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                            absl::Span<const int64> slice_sizes);
  friend XlaOp DynamicSlice(const XlaOp& operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                  const XlaOp& start_indices);
  friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64 dimension);

  friend void Trace(const string& tag, const XlaOp& operand);

  friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
                      const XlaOp& on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
  friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
                   const PrecisionConfig* precision_config);
  friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config);
  friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64 feature_group_count, int64 batch_group_count,
                    const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralPadding(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralDimensions(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count, int64 batch_group_count,
                           const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneralDilated(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                          absl::Span<const XlaOp> operands, const Shape& shape,
                          const string& opaque);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
  friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
                       absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Conj(const XlaOp& operand);
  friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Not(const XlaOp& operand);
  friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                                 absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      const XlaOp& operand, const XlaOp& init_value,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               absl::Span<const ReplicaGroup> replica_groups);
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               const XlaComputation& computation,
                               absl::Span<const ReplicaGroup> replica_groups,
                               const absl::optional<ChannelHandle>& channel_id);
  friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
                        int64 concat_dimension, int64 split_count,
                        const std::vector<ReplicaGroup>& replica_groups);
  friend XlaOp CollectivePermute(
      const XlaOp& operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
  friend XlaOp ReplicaId(XlaBuilder* builder);
  friend XlaOp SelectAndScatter(const XlaOp& operand,
                                const XlaComputation& select,
                                absl::Span<const int64> window_dimensions,
                                absl::Span<const int64> window_strides,
                                Padding padding, const XlaOp& source,
                                const XlaOp& init_value,
                                const XlaComputation& scatter);
  friend XlaOp SelectAndScatterWithGeneralPadding(
      const XlaOp& operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
      const XlaOp& init_value, const XlaComputation& scatter);
  friend XlaOp Abs(const XlaOp& operand);
  friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
                     absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Exp(const XlaOp& operand);
  friend XlaOp Expm1(const XlaOp& operand);
  friend XlaOp Floor(const XlaOp& operand);
  friend XlaOp Ceil(const XlaOp& operand);
  friend XlaOp Round(const XlaOp& operand);
  friend XlaOp Log(const XlaOp& operand);
  friend XlaOp Log1p(const XlaOp& operand);
  friend XlaOp Sign(const XlaOp& operand);
  friend XlaOp Clz(const XlaOp& operand);
  friend XlaOp Cos(const XlaOp& operand);
  friend XlaOp Sin(const XlaOp& operand);
  friend XlaOp Tanh(const XlaOp& operand);
  friend XlaOp Real(const XlaOp& operand);
  friend XlaOp Imag(const XlaOp& operand);
  friend XlaOp Sqrt(const XlaOp& operand);
  friend XlaOp Rsqrt(const XlaOp& operand);
  friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp IsFinite(const XlaOp& operand);
  friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
                    int64 iota_dimension);
  friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
  friend XlaOp ConvertElementType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp BitcastConvertType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp Neg(const XlaOp& operand);
  friend XlaOp Transpose(const XlaOp& operand,
                         absl::Span<const int64> permutation);
  friend XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
  friend XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
                    int64 dimension);
  friend XlaOp Sort(absl::Span<const XlaOp> operands,
                    const XlaComputation& comparator, int64 dimension,
                    bool is_stable);
  friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
  friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                   const XlaComputation& computation,
                   absl::Span<const int64> dimensions,
                   absl::Span<const XlaOp> static_operands);
  friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
                         const Shape& shape);
  friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
  friend XlaOp While(const XlaComputation& condition,
                     const XlaComputation& body, const XlaOp& init);
  friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                           const XlaComputation& true_computation,
                           const XlaOp& false_operand,
                           const XlaComputation& false_computation);
  friend XlaOp Conditional(
      const XlaOp& branch_index,
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);
  friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                               const int mantissa_bits);
  friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
                      const GatherDimensionNumbers& dimension_numbers,
                      absl::Span<const int64> slice_sizes);
  friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                       const XlaOp& updates,
                       const XlaComputation& update_computation,
                       const ScatterDimensionNumbers& dimension_numbers);
  friend void Send(const XlaOp& operand, const ChannelHandle& handle);
  friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
                    const ChannelHandle& handle);
  friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                                 const XlaOp& offset, float epsilon,
                                 int64 feature_index);
  friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                                  const XlaOp& offset, const XlaOp& mean,
                                  const XlaOp& variance, float epsilon,
                                  int64 feature_index);
  friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                             const XlaOp& batch_mean, const XlaOp& batch_var,
                             const XlaOp& grad_output, float epsilon,
                             int64 feature_index);
  friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                             const ChannelHandle& handle);
  friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                             const ChannelHandle& handle);
  friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                          const Shape& shape_with_layout,
                          const ChannelHandle& handle);
  friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                            const ChannelHandle& handle);
  friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                               const string& config);
  friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                                const Shape& shape_with_layout,
                                const string& outfeed_config);
  friend XlaOp CreateToken(XlaBuilder* builder);
  friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

  friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
};

// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
class XlaScopedShardingAssignment {
 public:
  XlaScopedShardingAssignment(xla::XlaBuilder* builder,
                              absl::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
  XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
      delete;

  ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  void SetSharding(const absl::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::XlaBuilder* const builder_;
  absl::optional<OpSharding> prev_sharding_;
};
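
// A minimal usage sketch (illustrative; `some_sharding` is a hypothetical
// OpSharding constructed elsewhere):
//
//   {
//     XlaScopedShardingAssignment assignment(&builder, some_sharding);
//     XlaOp op = ...;  // Ops built here have `some_sharding` attached.
//   }  // The previous sharding assignment is restored here.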

// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.
//

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);

// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//   Native Type   PrimitiveType
//  -----------------------------
//   bool          PRED
//   int32         S32
//   int64         S64
//   uint32        U32
//   uint64        U64
//   float         F32
//   double        F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
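
// For example (illustrative):
//
//   XlaOp pi = ConstantR0<float>(&builder, 3.14f);          // Scalar F32.
//   XlaOp zeros = ConstantR1<int32>(&builder, /*length=*/4,
//                                   /*value=*/0);           // S32[4] of 0s.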

// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes);

// This op broadcasts the `operand` to an output with the given `shape`.
// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the
// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
// dimension of the output. This also requires that the i'th input dimension is
// either 1 or is the same as the output dimension it's broadcasting into.
//
// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
// output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimensions will generate output
//   {{1, 2},
//    {1, 2}}
// - On the other hand, specifying {0} as broadcast_dimensions
//   will generate output
//   {{1, 1},
//    {2, 2}}
XlaOp BroadcastInDim(const XlaOp& operand,
                     const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions);

// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
          const PaddingConfig& padding_config);

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
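
// For example (illustrative), reshaping an f32[4,2] operand into f32[2,4]:
//
//   XlaOp r = Reshape(operand, /*new_sizes=*/{2, 4});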

// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
//   {256} collapsing {0} => {256}
//   {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
//   {256, 2} collapsing {0,1} => {512}
//   {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);

// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
//        x
//   [ 0 1 2 3 ]
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
//   [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice.
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides);

// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
//   array[:, 2:4:1, :]
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno);

// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',
// which specify the end point of exclusive slice intervals in each
// dimension [start, start + size).
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices' must
// have the same shape.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes);
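
// For example (an illustrative sketch; `i` and `j` are scalar S32 XlaOps
// computed at runtime), taking a 2x2 window of a rank-2 operand:
//
//   XlaOp window = DynamicSlice(operand, {i, j}, /*slice_sizes=*/{2, 2});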

ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                   absl::Span<const int64> slice_sizes);

// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
//               update = {10, 11} // calculated at runtime.
//   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
//   [4 5 6]  => DynamicUpdateslice(data, update, start)   => [4 10 11]
//   [7 8 9]                                                  [7 8  9 ]
//
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices' must
// have the same shape.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         absl::Span<const XlaOp> start_indices);

ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         const XlaOp& start_indices);

// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64 dimension);

// Enqueues a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
void Trace(const string& tag, const XlaOp& operand);

// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);

// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);

// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);

// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a comparison instruction onto the computation.
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction);
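
// A hedged broadcast sketch: with lhs of shape f32[2,3] and rhs of shape
// f32[3], broadcast_dimensions={1} maps the rhs onto dimension 1 of the lhs
// before comparing elementwise:
//
//   Gt(lhs, rhs, /*broadcast_dimensions=*/{1});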

// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
          const PrecisionConfig* precision_config = nullptr);

// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config = nullptr);
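
// A hedged sketch of a batched matmul via DotGeneral, for lhs f32[B,M,K] and
// rhs f32[B,K,N] (dimension numbers per xla_data.proto):
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   DotGeneral(lhs, rhs, dnums);  // result shape: f32[B,M,N]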

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> window_strides, Padding padding,
           int64 feature_group_count = 1, int64 batch_group_count = 1,
           const PrecisionConfig* precision_config = nullptr);
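
// For instance, a minimal sketch of a stride-1, same-padded 2D convolution
// with the default dimension numbers (operands illustrative):
//
//   Conv(input, kernel, /*window_strides=*/{1, 1}, Padding::kSame);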

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
                             absl::Span<const int64> window_strides,
                             absl::Span<const std::pair<int64, int64>> padding,
                             int64 feature_group_count = 1,
                             int64 batch_group_count = 1,
                             const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64 feature_group_count = 1, int64 batch_group_count = 1,
                  const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64 feature_group_count = 1,
                         int64 batch_group_count = 1,
                         const PrecisionConfig* precision_config = nullptr);

// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
          absl::Span<const int64> fft_length);

// Solves systems of linear equations with lower or upper triangular coefficient
// matrices by forward- or back-substitution. Broadcasting along leading
// dimensions, this routine solves for x in one of the matrix systems
//   `op(a) * x = b`, or `x * op(a) = b`,
// for the variable `x`, given `a` and `b`, where `op(a)` is either
//   `op(a) = a`, `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
//
// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
//   square matrices. If `lower` is true (false), then the strictly upper
//   (lower) triangular part of each innermost matrix in `a` is assumed to be
//   zero and is not accessed.
// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
//   tensor of shape `[..., K, M]`.
// * `left_side` is a boolean, indicating whether to solve a system of the form
//   op(a) * x = b (true) or x * op(a) = b (false).
// * `lower` is a boolean, indicating whether the argument `a` is
//   lower-triangular (true) or upper-triangular (false).
// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
//   1 and not accessed.
// * `transpose_a` indicates which function `op` is used to transform the
//   tensor `a`: the identity function, transpose(a), or conj(transpose(a)).
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a);
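
// Sketch: solving a * x = b by forward substitution for a lower-triangular,
// non-unit-diagonal `a`:
//
//   TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
//                   /*unit_diagonal=*/false,
//                   TriangularSolveOptions::NO_TRANSPOSE);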

// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
// positive definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
// two minor dimensions equal.
// If `lower` is true, the data from the lower triangle is used; if false, the
// upper triangle is used. The input data in the other triangle of the input
// does not affect the output. Returns the output in the same lower/upper
// triangle. The data returned in the other output triangle is arbitrary and
// implementation-defined.
//
// The value returned if `a` is not Hermitian positive definite is
// implementation-defined.
XlaOp Cholesky(XlaOp a, bool lower);

// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const string& config = "");

// Variant of Infeed which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                      const string& config = "");

// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
             const string& outfeed_config);

// Variant of Outfeed which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config);

// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands);

// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external code,
// and the external code is expected to produce a result of the given
// |shape|. The exact mechanism is backend-specific. For example, in the CPU
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
// but |call_target_name| should be short as it may be used in labels. |opaque|
// can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                 absl::Span<const XlaOp> operands, const Shape& shape,
                 const string& opaque = "");
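
// A hedged sketch (the target symbol name is illustrative): invoking external
// code that is expected to produce an f32[10] result from two operands:
//
//   CustomCall(builder, "do_custom_op", {x, y},
//              ShapeUtil::MakeShape(F32, {10}));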

// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
// external code, and the external code is expected to produce a result with the
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
// and |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
                           absl::Span<const XlaOp> operands,
                           const Shape& shape_with_layout,
                           absl::Span<const Shape> operand_shapes_with_layout,
                           const string& opaque = "");

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).

// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
              absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);

// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Element-wise logical operators
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Overload to call And with 3 or more operands. We need the following somewhat
// convoluted overload set to disambiguate with the overload that takes the
// `broadcast_dimensions` optional param.
inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
  return And(op1, And(op2, op3));
}
template <typename... XlaOpTs>
XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
          const XlaOpTs&... operands) {
  return And(op1, And(op2, And(op3, operands...)));
}
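
// For example, And(a, b, c, d) expands to And(a, And(b, And(c, d))).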

XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Overload to call Or with 3 or more operands. As with `And`, we need the
// following complicated overload set to handle the default arg in the `Or`
// overload above.
inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
  return Or(op1, Or(op2, op3));
}
template <typename... XlaOpTs>
XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
         const XlaOpTs&... operands) {
  return Or(op1, Or(op2, Or(op3, operands...)));
}

XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Not(const XlaOp& operand);

XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                        absl::Span<const int64> broadcast_dimensions = {});

// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);
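
// A minimal sketch of a sum over dimension 0 of an f32 operand; the scalar
// add computation is built by hand here (builder names illustrative, error
// handling elided):
//
//   XlaBuilder sub("add");
//   Add(Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "x"),
//       Parameter(&sub, 1, ShapeUtil::MakeShape(F32, {}), "y"));
//   XlaComputation add = sub.Build().ValueOrDie();
//   Reduce(operand, ConstantR0<float>(builder, 0.0f), add,
//          /*dimensions_to_reduce=*/{0});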

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                const XlaComputation& computation);

// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);
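
// Sketch: a 2x2 max pool with stride 2 over an f32[N,H,W,C]-shaped operand,
// where `max` is a scalar max computation built the same way as the add
// computation in the Reduce sketch above:
//
//   ReduceWindow(operand,
//                ConstantR0<float>(builder,
//                                  -std::numeric_limits<float>::infinity()),
//                max, /*window_dimensions=*/{1, 2, 2, 1},
//                /*window_strides=*/{1, 2, 2, 1}, Padding::kValid);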

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp& operand, const XlaOp& init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(const XlaOp& operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});

// Enqueues an operation that does an AllReduce of the operand across cores.
// Here AllReduce means doing a reduction on the input operand across cores and
// then broadcasting the reduction result to those cores. The reduction
// function is defined by `computation`, which should be a commutative
// computation on scalars, e.g., add, min, or max. The way that AllReduce is
// applied is configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
// empty, all replicas belong to one group. AllReduce will be applied within
// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
// replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied across modules.
//
// TODO(b/117564385): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
    const XlaOp& operand, const XlaComputation& computation,
    absl::Span<const ReplicaGroup> replica_groups = {},
    const absl::optional<ChannelHandle>& channel_id = absl::nullopt);

// Enqueues an operation that does an AllToAll of the operand across cores.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
               int64 concat_dimension, int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups = {});

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that (1) no two pairs may share the same target replica
// id, nor the same source replica id; and (2) if a replica id is not a target
// in any pair, then the output on that replica is a tensor of zeros with the
// same shape as the input.
XlaOp CollectivePermute(
    const XlaOp& operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs);
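
// For example, with 4 replicas, source_target_pairs={{0,1},{1,2},{2,3}}
// rotates the operand forward by one replica; replica 0 is not a target in
// any pair, so it receives zeros.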

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp& source, const XlaOp& init_value,
                       const XlaComputation& scatter);

// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp& operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    const XlaOp& init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);

// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
            absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(const XlaOp& operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(const XlaOp& operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(const XlaOp& operand);

// Enqueues a round instruction onto the computation, rounding to the nearest
// integer, with half-way cases rounding away from zero.
XlaOp Round(const XlaOp& operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(const XlaOp& operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(const XlaOp& operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(const XlaOp& operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(const XlaOp& operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(const XlaOp& operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(const XlaOp& operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(const XlaOp& operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(const XlaOp& operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);

// Enqueues a sqrt instruction onto the computation.
XlaOp Sqrt(const XlaOp& operand);

// Enqueues an rsqrt instruction onto the computation.
XlaOp Rsqrt(const XlaOp& operand);

// Enqueues a power instruction (lhs^rhs) onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e., not
// +/-Inf or NaN. Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(const XlaOp& operand);

// Enqueues an iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
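
// For example, Iota(builder, S32, 5) yields the rank-1 array {0, 1, 2, 3, 4}.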

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);

// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);

// Enqueues a sort (as increasing order) instruction onto the computation.
// If only keys are provided:
// * If the keys are a rank-1 tensor (an array), the result is a sorted array
//   of keys, in ascending order.
// * If the keys have higher rank, the keys are sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
//   value of 0 will independently sort every column, and a dimension value of 1
//   will independently sort each row. If no dimension number is provided, then
//   the last dimension is chosen by default.
//
// If both keys and values are provided:
// * The keys and all values must be tensors with the same dimensions. The
//   element types of the tensors may be different.
// * The result is a tuple that consists of a sorted tensor of keys (along the
//   provided dimension, as above) as the first element, and tensors with their
//   corresponding values as the other elements.
ABSL_DEPRECATED("Use form with comparator computation instead")
XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
           int64 dimension = -1);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided, then the
//   last dimension is chosen by default. For the dimension which is sorted, the
//   same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types of
//   the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
//   two index positions.
// Default comparator computations can be found in lib/comparators.h.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension = -1, bool is_stable = false);
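
// A hedged sketch: sorting a single f32 operand in ascending order with a
// hand-built less-than comparator (two scalar parameters per operand; error
// handling elided):
//
//   XlaBuilder cmp("cmp");
//   Lt(Parameter(&cmp, 0, ShapeUtil::MakeShape(F32, {}), "a"),
//      Parameter(&cmp, 1, ShapeUtil::MakeShape(F32, {}), "b"));
//   Sort({operand}, cmp.Build().ValueOrDie());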

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);

// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp& init);
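
// A minimal sketch of a counting loop (builder names illustrative, error
// handling elided): the loop carries a scalar S32 value, incrementing it
// until it reaches 10:
//
//   XlaBuilder cond_b("cond");
//   Lt(Parameter(&cond_b, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//      ConstantR0<int32>(&cond_b, 10));
//   XlaBuilder body_b("body");
//   Add(Parameter(&body_b, 0, ShapeUtil::MakeShape(S32, {}), "i"),
//       ConstantR0<int32>(&body_b, 1));
//   While(cond_b.Build().ValueOrDie(), body_b.Build().ValueOrDie(),
//         ConstantR0<int32>(builder, 0));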

// Enqueues a conditional node onto the computation.
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp& false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index selects the (N-1)'th
// branch_computation as the default.
XlaOp Conditional(const XlaOp& branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes);

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
              const XlaOp& updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                    const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                 const Shape& shape_with_layout, const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                        const XlaOp& offset, float epsilon,
                        int64 feature_index);
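
// The three outputs are accessed via GetTupleElement; a sketch (epsilon and
// feature_index illustrative):
//
//   XlaOp result = BatchNormTraining(operand, scale, offset,
//                                    /*epsilon=*/1e-5, /*feature_index=*/3);
//   XlaOp normalized = GetTupleElement(result, 0);
//   XlaOp batch_mean = GetTupleElement(result, 1);
//   XlaOp batch_var = GetTupleElement(result, 2);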

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                         const XlaOp& offset, const XlaOp& mean,
                         const XlaOp& variance, float epsilon,
                         int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
// - grad_operand: Gradient with respect to input `operand`
// - grad_offset: Gradient with respect to input `offset`
// - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                    const XlaOp& batch_mean, const XlaOp& batch_var,
                    const XlaOp& grad_output, float epsilon,
                    int64 feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);

// Implementation details below this point.
//

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_