/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <set>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

// Note: Most of the operators defined in this module are used by the jax2tf
// converter (see go/jax2tf for details) and are used in SavedModels produced
// by jax2tf. Hence, we need to maintain backwards compatibility for these
// operators. Please reach out to the JAX team if you want to make changes.

namespace tensorflow {
namespace {

// Helper shape function for operators that return an output with the same rank
// as their first input.
Status UnchangedRank(shape_inference::InferenceContext* c) {
  if (c->RankKnown(c->input(0))) {
    c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
  } else {
    c->set_output(0, c->input(0));
  }
  return OkStatus();
}

REGISTER_OP("XlaBroadcastHelper")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("broadcast_dims: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Output("lhs_output: T")
    .Output("rhs_output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Helper operator for performing XLA-style broadcasts

Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.

lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");

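/* Illustrative sketch (hypothetical helper, not used by this op): the shape
the lower-rank operand takes after XlaBroadcastHelper, assuming that
broadcast_dims[i] names the dimension of the higher-rank operand that
dimension i of the lower-rank operand maps to, as in XLA's broadcasting rules
for binary operators. Every other dimension becomes size 1.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: shape of the lower-rank operand after it is reshaped
// to `high_rank` dimensions by inserting size-1 dimensions.
std::vector<int64_t> BroadcastShapeSketch(
    const std::vector<int64_t>& low_rank_shape, int64_t high_rank,
    const std::vector<int64_t>& broadcast_dims) {
  if (high_rank < static_cast<int64_t>(low_rank_shape.size()) ||
      broadcast_dims.size() != low_rank_shape.size()) {
    throw std::invalid_argument("need one broadcast dim per low-rank dim");
  }
  // Start from all-ones, then place each low-rank size at its mapped slot.
  std::vector<int64_t> result(static_cast<std::size_t>(high_rank), 1);
  for (std::size_t i = 0; i < broadcast_dims.size(); ++i) {
    const int64_t dim = broadcast_dims[i];
    if (dim < 0 || dim >= high_rank) {
      throw std::out_of_range("broadcast dim out of range");
    }
    result[static_cast<std::size_t>(dim)] = low_rank_shape[i];
  }
  return result;
}
```

For example, a length-3 vector broadcast against a 2x3 matrix with
broadcast_dims = {1} becomes a 1x3 operand; with {0} (against a 3x2 matrix)
it becomes 3x1. */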
REGISTER_OP("XlaSelfAdjointEig")
    .Input("a: T")
    .Attr("lower: bool")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Output("w: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).

Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
`a` such that a[..., :, :] * v[..., :, i] = w[..., i] * v[..., :, i], for
i=0...N-1.

a: the input tensor.

lower: a boolean that specifies whether the calculation is done with the lower
  triangular part or the upper triangular part.

max_iter: maximum number of sweep updates, where one sweep covers the whole
  lower or upper triangular part (selected by `lower`). Heuristically, it has
  been argued that approximately log(N) sweeps are needed in practice (Ref:
  Golub & Van Loan, "Matrix Computations").

epsilon: the tolerance ratio.

w: The eigenvalues in ascending order, each repeated according to its
  multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
  eigenvalue w[..., i].
)doc");

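/* Illustrative sketch (hypothetical helper, not part of this op): the
XlaSelfAdjointEig contract for a single real symmetric 2x2 matrix
[[a, b], [b, d]]. It uses the closed form for 2x2 eigenvalues rather than the
iterative sweeps that max_iter and epsilon control, and returns w in
ascending order with unit eigenvectors in the columns of v.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// Hypothetical result type: w ascending, column i of v pairs with w[i].
struct Eig2x2 {
  std::array<double, 2> w;
  std::array<std::array<double, 2>, 2> v;  // v[row][col]
};

// Hypothetical helper: eigen decomposition of [[a, b], [b, d]].
Eig2x2 SelfAdjointEig2x2Sketch(double a, double b, double d) {
  Eig2x2 out{};
  const double mean = 0.5 * (a + d);
  const double radius = std::hypot(0.5 * (a - d), b);
  out.w = {mean - radius, mean + radius};
  for (int i = 0; i < 2; ++i) {
    double x = 0.0;
    double y = 0.0;
    if (b != 0.0) {
      // (b, w - a) satisfies the first row of (A - w I) x = 0.
      x = b;
      y = out.w[i] - a;
    } else {
      // Diagonal input: the eigenvectors are the axis vectors.
      const bool first_axis = (a <= d) == (i == 0);
      x = first_axis ? 1.0 : 0.0;
      y = first_axis ? 0.0 : 1.0;
    }
    const double norm = std::hypot(x, y);
    out.v[0][i] = x / norm;
    out.v[1][i] = y / norm;
  }
  return out;
}
```

For [[2, 1], [1, 2]] this gives w = {1, 3}, and multiplying the input by each
column of v reproduces w[i] times that column. */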
REGISTER_OP("XlaSvd")
    .Input("a: T")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Attr("precision_config: string")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the singular value decomposition of a batch of matrices
(Note: Only real inputs are supported).

Computes the singular values and singular vectors of the innermost M-by-N
matrices in `a` such that
a[..., :, :] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[..., :, :]).

a: the input tensor.

max_iter: maximum number of sweep updates. Heuristically, it has been argued
  that approximately log(min(M, N)) sweeps are needed in practice
  (Ref: Golub & Van Loan, "Matrix Computations").

epsilon: the tolerance ratio.

precision_config: a serialized xla::PrecisionConfig proto.

s: Singular values. The values are sorted in reverse order of magnitude, so
  s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");

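/* Illustrative sketch (hypothetical helper, not part of this op): reading the
XlaSvd factorization elementwise, a[i][j] = sum_k u[i][k] * s[k] * v[j][k].
The helper only rebuilds one matrix from given factors; it does not compute an
SVD. Only the first s.size() = min(M, N) columns of u and v enter the sum, so
it applies whether or not u and v are returned square (an assumption this
sketch does not depend on).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Hypothetical helper: computes u * Diag(s) * Transpose(v) for one matrix.
Matrix ReconstructFromSvdSketch(const Matrix& u, const std::vector<double>& s,
                                const Matrix& v) {
  const std::size_t m = u.size();
  const std::size_t n = v.size();
  Matrix a(m, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      for (std::size_t k = 0; k < s.size(); ++k) {
        a[i][j] += u[i][k] * s[k] * v[j][k];
      }
    }
  }
  return a;
}
```

In the check below, the factors of [[3, 0], [0, -2]] carry the negative sign in
u, so s = {3, 2} stays non-negative and descending as the doc string requires. */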
REGISTER_OP("XlaConv")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

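/* Illustrative sketch (hypothetical helper, not part of this op): the shape
function above is UnchangedRank, so it leaves each output size unknown. This
shows how window_strides, padding, lhs_dilation and rhs_dilation would
combine to size one spatial output dimension, assuming the usual
ConvGeneralDilated rule: dilate the input, pad it, dilate the kernel, then
slide the kernel with the stride.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: size of one spatial output dimension. Assumes
// input_size, kernel_size, stride and both dilations are >= 1.
int64_t ConvOutputSizeSketch(int64_t input_size, int64_t kernel_size,
                             int64_t stride, int64_t pad_low, int64_t pad_high,
                             int64_t lhs_dilation, int64_t rhs_dilation) {
  // lhs_dilation inserts (lhs_dilation - 1) holes between input elements.
  const int64_t dilated_input = (input_size - 1) * lhs_dilation + 1;
  const int64_t padded_input = dilated_input + pad_low + pad_high;
  // rhs_dilation spreads the kernel taps apart in the same way.
  const int64_t dilated_kernel = (kernel_size - 1) * rhs_dilation + 1;
  // No window position fits: the output dimension is empty.
  if (padded_input < dilated_kernel) return 0;
  return (padded_input - dilated_kernel) / stride + 1;
}
```

For a length-5 input and a length-3 kernel this gives 3 with no padding, 5
with one element of padding on each side, and 2 with stride 2. */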
REGISTER_OP("XlaConvV2")
    .Input("lhs: LhsT")
    .Input("rhs: RhsT")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("LhsT: numbertype")
    .Attr("RhsT: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Attr("preferred_element_type: numbertype")
    .Attr("batch_group_count: int = 1")
    .Output("output: preferred_element_type")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: input tensor
rhs: kernel tensor
window_strides: inter-window strides
padding: padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: serialized xla::ConvolutionDimensionNumbers proto.
precision_config: serialized xla::PrecisionConfig proto.
preferred_element_type: element type of the output tensor.
batch_group_count: number of batch groups or grouped filters.
)doc");

static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
  shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
  if (!c->RankKnown(lhs_shape_handle) || !c->RankKnown(rhs_shape_handle)) {
    return shape_inference::UnknownShape(c);
  }

  string dimension_numbers_string;
  TF_RETURN_IF_ERROR(
      c->GetAttr("dimension_numbers", &dimension_numbers_string));

  xla::DotDimensionNumbers dimension_numbers;
  dimension_numbers.ParseFromString(dimension_numbers_string);

  // Check that the number of contracting dimensions matches.
  if (dimension_numbers.lhs_contracting_dimensions_size() !=
      dimension_numbers.rhs_contracting_dimensions_size())
    return errors::InvalidArgument(
        "Must specify the same number of contracting dimensions for lhs "
        "and rhs. Got: ",
        dimension_numbers.lhs_contracting_dimensions_size(), " and ",
        dimension_numbers.rhs_contracting_dimensions_size());

  // Check that the contracting dimension sizes match.
  for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
       ++i) {
    const int64_t lhs_contracting_dimension =
        dimension_numbers.lhs_contracting_dimensions(i);
    const int64_t rhs_contracting_dimension =
        dimension_numbers.rhs_contracting_dimensions(i);
    shape_inference::DimensionHandle unused;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension),
                 c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension),
                 &unused),
        "For contracting dimension ", i, " which is lhs dimension ",
        lhs_contracting_dimension, " and rhs dimension ",
        rhs_contracting_dimension);
  }

  // Check that the number of batch dimensions matches.
  if (dimension_numbers.lhs_batch_dimensions_size() !=
      dimension_numbers.rhs_batch_dimensions_size())
    return errors::InvalidArgument(
        "Must specify the same number of batch dimensions for lhs "
        "and rhs. Got: ",
        dimension_numbers.lhs_batch_dimensions_size(), " and ",
        dimension_numbers.rhs_batch_dimensions_size());

  // Each of lhs and rhs contributes its rank minus its number of contracting
  // and batch dimensions to the rank of the result, and the shared batch
  // dimensions are added once. When an input tensor is a scalar, its
  // contribution to the rank of the result is 0. The result dimensions are
  // generated in order: batch dimensions, then the non-contracting, non-batch
  // lhs dimensions, then the non-contracting, non-batch rhs dimensions.
  std::vector<shape_inference::DimensionHandle> output_dims;

  // Check that the batch dimension sizes match, and add them to output_dims.
  for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
    const int64_t lhs_batch_dimension =
        dimension_numbers.lhs_batch_dimensions(i);
    const int64_t rhs_batch_dimension =
        dimension_numbers.rhs_batch_dimensions(i);
    shape_inference::DimensionHandle out;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension),
                 c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension), &out),
        "For batch dimension ", i, " which is lhs dimension ",
        lhs_batch_dimension, " and rhs dimension ", rhs_batch_dimension);
    output_dims.emplace_back(out);
  }

  const int32_t lhs_rank = c->Rank(lhs_shape_handle);
  for (int64_t i = 0; i < lhs_rank; ++i) {
    if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
                              i) ||
        absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
      continue;
    }
    output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
  }

  const int32_t rhs_rank = c->Rank(rhs_shape_handle);
  for (int64_t i = 0; i < rhs_rank; ++i) {
    if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
                              i) ||
        absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
      continue;
    }
    output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
  }

  c->set_output(0, c->MakeShape(output_dims));
  return OkStatus();
}

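/* Illustrative sketch (hypothetical helper, not part of TensorFlow): the
output-dimension ordering that XlaDotShapeFunction produces, restated for
fully static shapes. It assumes the dimension numbers were already validated
(matching counts and sizes), so it reads each batch size from lhs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: batch dims, then lhs free dims, then rhs free dims.
std::vector<int64_t> DotOutputShapeSketch(
    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs,
    const std::vector<int64_t>& lhs_batch, const std::vector<int64_t>& rhs_batch,
    const std::vector<int64_t>& lhs_contracting,
    const std::vector<int64_t>& rhs_contracting) {
  auto contains = [](const std::vector<int64_t>& v, int64_t x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  std::vector<int64_t> out;
  for (int64_t d : lhs_batch) out.push_back(lhs[d]);
  for (int64_t i = 0; i < static_cast<int64_t>(lhs.size()); ++i) {
    if (!contains(lhs_batch, i) && !contains(lhs_contracting, i)) {
      out.push_back(lhs[i]);
    }
  }
  for (int64_t i = 0; i < static_cast<int64_t>(rhs.size()); ++i) {
    if (!contains(rhs_batch, i) && !contains(rhs_contracting, i)) {
      out.push_back(rhs[i]);
    }
  }
  return out;
}
```

A batched matmul of [2, 3, 4] by [2, 4, 5] (batch {0}/{0}, contracting
{2}/{1}) yields [2, 3, 5]; a plain matmul of [3, 4] by [4, 5] yields [3, 5]. */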
REGISTER_OP("XlaDot")
    .Input("lhs: T")
    .Input("rhs: T")
    .Attr("T: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(XlaDotShapeFunction)
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaDotV2")
    .Input("lhs: LhsT")
    .Input("rhs: RhsT")
    .Attr("LhsT: numbertype")
    .Attr("RhsT: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Attr("preferred_element_type: numbertype")
    .Output("output: preferred_element_type")
    .SetShapeFn(XlaDotShapeFunction)
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
preferred_element_type: The element type of the output tensor.
)doc");

REGISTER_OP("XlaSetBound")
    .Input("input: int32")
    .Input("bound: int32")
    .Output("output: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Sets a bound for the given input value as a hint to the XLA
compiler, and returns the same value.
)doc");

REGISTER_OP("XlaSetDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Input("size: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Makes static dimension `dim_index` of `input` into an XLA bounded
dynamic dimension. The current static dimension size becomes the bound, and
`size` becomes the dynamic size of the dimension.)doc");

REGISTER_OP("XlaRemoveDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Inverse of XlaSetDynamicDimensionSize.

Makes an XLA bounded dynamic dimension into a static dimension. The bound of
the size of dimension `dim_index` becomes the static dimension size.
)doc");

REGISTER_OP("XlaDynamicSlice")
    .Input("input: T")
    .Input("start_indices: Tindices")
    .Input("size_indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      shape_inference::ShapeHandle size_indices_shape = c->input(2);
      if (!c->RankKnown(size_indices_shape)) {
        return UnchangedRank(c);
      }
      if (c->Rank(size_indices_shape) != 1) {
        return errors::InvalidArgument("size_indices must be a 1D tensor");
      }
      shape_inference::ShapeHandle size_indices_value;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &size_indices_value));
      if (!c->RankKnown(size_indices_value)) {
        // If we cannot tell the rank of the output from the value of
        // size_indices, perhaps we can find it from the rank of first operand.
        return UnchangedRank(c);
      }
      c->set_output(0, size_indices_value);
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA DynamicSlice operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
.

DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
dimension: [start, start + size). start_indices must have rank 1, with
dimension size equal to the rank of `input`.

input: A `Tensor` of type T.

start_indices: Rank 1 tensor of N integers containing the starting indices of
  the slice for each dimension. Values must be greater than or equal to zero.

size_indices: Rank 1 tensor of N integers containing the slice size for each
  dimension. Each value must be strictly greater than zero, and start + size
  must be less than or equal to the size of the dimension to avoid
  implementation defined behavior.
)doc");

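/* Illustrative sketch (hypothetical helper, not part of this op): a 1-D
DynamicSlice returning input[start, start + size). The clamp of `start` into
[0, dim - size] follows the start-index clamping described in XLA's operation
semantics; the doc string above instead calls out-of-range starts
implementation defined, so treat the clamp as an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper. Assumes 0 <= size <= input.size().
std::vector<int> DynamicSlice1DSketch(const std::vector<int>& input,
                                      int64_t start, int64_t size) {
  const int64_t dim = static_cast<int64_t>(input.size());
  assert(size >= 0 && size <= dim);
  // Keep the whole window [start, start + size) inside the input.
  start = std::clamp<int64_t>(start, 0, dim - size);
  return std::vector<int>(input.begin() + start,
                          input.begin() + start + size);
}
```

On {0, 1, 2, 3, 4}, a slice of size 3 at start 1 gives {1, 2, 3}, while start
4 is clamped to 2 and gives {2, 3, 4}. */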
REGISTER_OP("XlaDynamicUpdateSlice")
    .Input("input: T")
    .Input("update: T")
    .Input("indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.

XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice `update` overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. `indices`
must have rank 1, with dimension size equal to the rank of `input`.

Handling of out-of-bounds slice indices is implementation-defined.

input: A `Tensor` of type T.
update: A `Tensor` of type T. Same rank as `input`.
indices: A vector of indices into `input`. Must have length equal to the rank of
  `input`.
output: A `Tensor` of type T.
)doc");

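/* Illustrative sketch (hypothetical helper, not part of this op): a 1-D
DynamicUpdateSlice that overwrites a copy of `input` with `update` starting at
`index`. As with the slice sketch, clamping `index` into
[0, dim - update.size()] is an assumption taken from XLA's operation
semantics; this op's doc string leaves out-of-bounds handling
implementation-defined.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper. Assumes update.size() <= input.size().
std::vector<int> DynamicUpdateSlice1DSketch(std::vector<int> input,
                                            const std::vector<int>& update,
                                            int64_t index) {
  const int64_t dim = static_cast<int64_t>(input.size());
  const int64_t size = static_cast<int64_t>(update.size());
  assert(size <= dim);
  // Keep the updated window inside the input, then overwrite it.
  index = std::clamp<int64_t>(index, 0, dim - size);
  std::copy(update.begin(), update.end(), input.begin() + index);
  return input;
}
```

The result always has the shape of `input`, which is why the op can use
UnchangedShape as its shape function. */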
// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
    .Input("cond: Tcond")
    .Input("inputs: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).

cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
        else_branch(inputs). The input shapes of the then_branch and
        else_branch must match.
then_branch: A function that takes 'inputs' and returns a list of tensors,
             whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of tensors,
             whose types are the same as what then_branch returns.
)doc");

REGISTER_OP("XlaPad")
    .Input("input: T")
    .Input("padding_value: T")
    .Input("padding_low: Tindices")
    .Input("padding_high: Tindices")
    .Input("padding_interior: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle input_shape_handle = c->input(0);
      if (!c->RankKnown(input_shape_handle)) {
        return UnchangedRank(c);
      }
      const int32_t op_rank = c->Rank(input_shape_handle);

      shape_inference::ShapeHandle padding_shape_handle = c->input(1);
      if (c->RankKnown(padding_shape_handle) &&
          c->Rank(padding_shape_handle) != 0) {
        return errors::InvalidArgument(
            "padding_value input must be scalar, found rank ",
            c->Rank(padding_shape_handle));
      }
      const Tensor* padding_low_tensor = c->input_tensor(2);
      const Tensor* padding_high_tensor = c->input_tensor(3);
      const Tensor* padding_interior_tensor = c->input_tensor(4);
      if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
          padding_interior_tensor == nullptr) {
        return UnchangedRank(c);
      }

      if (padding_low_tensor->shape().dims() != 1 ||
          padding_low_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_low must be a 1D tensor of size ", op_rank);
      }
      if (padding_high_tensor->shape().dims() != 1 ||
          padding_high_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_high must be a 1D tensor of size ", op_rank);
      }
      if (padding_interior_tensor->shape().dims() != 1 ||
          padding_interior_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_interior must be a 1D tensor of size ", op_rank);
      }
      std::vector<shape_inference::DimensionHandle> output_dims;
      output_dims.reserve(op_rank);
      for (int64_t i = 0; i < op_rank; ++i) {
        int64_t low, high, interior;
        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_high_tensor, i, &high));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
        if (interior < 0) {
          return errors::InvalidArgument(
              "padding_interior must contain only non-negative values, found ",
              interior);
        }

        shape_inference::DimensionHandle orig_size_handle =
            c->Dim(input_shape_handle, i);
        if (c->ValueKnown(orig_size_handle)) {
          auto orig_dim = c->Value(orig_size_handle);
          int64_t new_dim = orig_dim + low + high;
          if (orig_dim > 0) {
            new_dim += interior * (orig_dim - 1);
          }
          if (new_dim < 0) {
            return errors::InvalidArgument(
                "resulting padded dimension has negative size ", new_dim);
          }
          output_dims.emplace_back(c->MakeDim(new_dim));
        } else {
          output_dims.emplace_back(c->UnknownDim());
        }
      }

      c->set_output(0, c->MakeShape(output_dims));
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Pad operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#pad
.

input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimension. Must
  be a compile-time constant 1D tensor of length equal to rank of input.
padding_high: the padding to apply at the end of each input dimension. Must
  be a compile-time constant 1D tensor of length equal to rank of input.
padding_interior: the padding to apply between each input element. Must
  be a compile-time constant 1D tensor of length equal to rank of input,
  containing only non-negative values.
output: A `Tensor` of type T.
)doc");

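/* Illustrative sketch (hypothetical helper, not part of this op): XlaPad on a
1-D input with non-negative padding. The result length is
low + high + n + interior * (n - 1) for n = input.size() > 0, the same
per-dimension formula the XlaPad shape function applies. XLA also accepts
negative low/high padding, which removes elements; this sketch does not model
that case.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper. Assumes low, high and interior are all >= 0.
std::vector<int> Pad1DSketch(const std::vector<int>& input, int padding_value,
                             int64_t low, int64_t high, int64_t interior) {
  assert(low >= 0 && high >= 0 && interior >= 0);
  std::vector<int> out(static_cast<std::size_t>(low), padding_value);
  for (std::size_t i = 0; i < input.size(); ++i) {
    // Interior padding goes between neighbouring elements only.
    if (i > 0) {
      out.insert(out.end(), static_cast<std::size_t>(interior), padding_value);
    }
    out.push_back(input[i]);
  }
  out.insert(out.end(), static_cast<std::size_t>(high), padding_value);
  return out;
}
```

Padding {1, 2, 3} with value 0, low = 1, high = 2, interior = 1 gives
{0, 1, 0, 2, 0, 3, 0, 0}, length 3 + 1 + 2 + 1 * 2 = 8. */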
REGISTER_OP("XlaRecv")
    .Output("tensor: dtype")
    .Attr("dtype: type")
    .Attr("tensor_name: string")
    .Attr("shape: shape")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return OkStatus();
    })
    .Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#recv .

tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");

REGISTER_OP("XlaReduce")
    .Input("input: T")
    .Input("init_value: T")
    .Attr("T: {numbertype, bool}")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64_t> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
                                   dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaReduce");
        }
        c->set_output(
            0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
      } else {
        c->set_output(0, c->input(0));
      }
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Reduce operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#reduce .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

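/* Illustrative sketch (hypothetical helper, not part of TensorFlow): the
dimensions_to_reduce validation shared by the XlaReduce family, restated for
static shapes. XlaReduce's shape function above only fixes the output rank
(rank minus the number of reduced dimensions); XlaVariadicReduceV2's also
keeps the surviving dimension sizes, which is what this sketch returns.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <set>
#include <vector>

// Hypothetical helper: std::nullopt for duplicate or out-of-range dims,
// otherwise the input shape with the reduced dimensions dropped.
std::optional<std::vector<int64_t>> ReducedShapeSketch(
    const std::vector<int64_t>& shape, const std::vector<int64_t>& dims) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  const std::set<int64_t> dims_set(dims.begin(), dims.end());
  if (dims_set.size() != dims.size()) return std::nullopt;  // duplicates
  for (int64_t d : dims) {
    if (d < 0 || d >= rank) return std::nullopt;  // out of range
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    if (dims_set.count(i) == 0) out.push_back(shape[i]);
  }
  return out;
}
```

Reducing a [2, 3, 4] input over {1} gives [2, 4]; reducing over every
dimension gives a scalar (an empty shape). */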
REGISTER_OP("XlaVariadicReduce")
    .Input("input: N * T")
    .Input("init_value: N * T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, bool}")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: N * T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          c->MergeInput(i, c->input(j));
        }
      }
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64_t> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
                                   dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
        }
        for (int i = 0; i < n; i++) {
          c->set_output(
              i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
        }
      } else {
        for (int i = 0; i < n; i++) {
          c->set_output(i, c->input(i));
        }
      }
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

This version is limited to operands of the same dtype.
XlaVariadicReduceV2 is a version that supports heterogeneous operands.

input: the input tensor(s)
init_value: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

704 REGISTER_OP("XlaVariadicReduceV2")
705     .Input("inputs: T")
706     .Input("init_values: T")
707     .Attr("T: list(type) >= 1")
708     .Attr("dimensions_to_reduce: list(int)")
709     .Attr("reducer: func")
710     .Output("outputs: T")
__anonb720852f0902(shape_inference::InferenceContext* c) 711     .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      std::vector<shape_inference::ShapeHandle> init_values_shapes;
      TF_RETURN_IF_ERROR(c->input("init_values", &init_values_shapes));
      const int nr_inputs = input_shapes.size();
      if (nr_inputs != init_values_shapes.size()) {
        return errors::InvalidArgument(
            "Must specify the same number of inputs and init_values. ", "Got ",
            nr_inputs, " and ", init_values_shapes.size());
      }
      if (nr_inputs == 0) {
        return errors::InvalidArgument("Must specify at least one input");
      }

      shape_inference::ShapeHandle input_shape = input_shapes[0];
      for (int i = 1; i < nr_inputs; ++i) {
        shape_inference::ShapeHandle merged;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->Merge(input_shape, input_shapes[i], &merged),
            "All inputs must have the same shape. Input ", i,
            " (zero-based) has shape ", c->DebugString(input_shapes[i]),
            " incompatible with the shape ", "inferred from previous inputs ",
            c->DebugString(input_shape));
        input_shape = merged;
      }
      // All outputs have the same shape
      shape_inference::ShapeHandle output_shape = c->UnknownShape();

      if (c->RankKnown(input_shape)) {
        int rank = c->Rank(input_shape);

        std::vector<int64_t> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
                                   dimensions_to_reduce.end());

        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2");
        }

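        // Illustrative example (not an additional check): for merged inputs
        // of shape [2, 3, 4] and dimensions_to_reduce = {0, 2}, the loop below
        // keeps only the non-reduced dimension 1, so every output is inferred
        // to have shape [3].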
        std::vector<shape_inference::DimensionHandle> output_dims;
        for (int64_t i = 0; i < rank; ++i) {
          if (dims_set.find(i) == dims_set.end()) {
            output_dims.emplace_back(c->Dim(input_shape, i));
          }
        }
        output_shape = c->MakeShape(output_dims);
      }
      for (int i = 0; i < nr_inputs; ++i) {
        c->set_output(i, output_shape);
      }
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce .

This is an expanded version of XlaVariadicReduce, with support for
operands of different dtypes, and improved shape inference.

inputs: the input tensor(s)
init_values: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaReduceWindow")
    .Input("input: T")
    .Input("init_value: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("base_dilations: Tindices")
    .Input("window_dilations: Tindices")
    .Input("padding: Tindices")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Attr("computation: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
base_dilations: the dilation factors applied to the input (base) tensor
window_dilations: the dilation factors applied to the window
padding: the padding to apply at the start and end of each input dimension
)doc");

REGISTER_OP("XlaRngBitGenerator")
    .Input("algorithm: int32")
    .Input("initial_state: uint64")
    .Input("shape: Tshape")
    .Output("output_key: uint64")
    .Output("output: dtype")
    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle algorithm;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm));
      shape_inference::ShapeHandle initial_state;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state));

      c->set_output(0, initial_state);
      shape_inference::ShapeHandle output;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output));
      c->set_output(1, output);
      return OkStatus();
    })
    .Doc(R"doc(
Stateless PRNG bit generator.
Wraps the XLA RngBitGenerator operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator .

algorithm: The PRNG algorithm to use, one of
  tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be
  a u64[2] and for PHILOX a u64[3].
shape: The output shape of the generated data.
dtype: The element type of the generated output tensor.
)doc");

REGISTER_OP("XlaSelectAndScatter")
    .Input("operand: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("source: T")
    .Input("init_value: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("select: func")
    .Attr("scatter: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA SelectAndScatter operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter .

operand: the input tensor
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
source: a tensor of values to scatter
init_value: a scalar representing the initial value for the output tensor
select: a selection function to apply
scatter: a scatter function to apply
)doc");

REGISTER_OP("XlaSend")
    .Input("tensor: T")
    .Attr("T: type")
    .Attr("tensor_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#send .

tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");

REGISTER_OP("XlaSort")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort .

Sorts a tensor. Currently only ascending-order sorts are supported.

input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");

REGISTER_OP("XlaKeyValueSort")
    .Input("keys: K")
    .Input("values: V")
    .Output("sorted_keys: K")
    .Output("sorted_values: V")
    .Attr("K: realnumbertype")
    .Attr("V: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort .

Sorts the keys and reorders the values to match. Currently only ascending-order
sorts are supported.

keys: A `Tensor` of type K.
values: A `Tensor` of type V.
sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");

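// Illustrative example for XlaVariadicSort, assuming XLA's Sort convention that
// comparator parameters 2*k and 2*k+1 are the two scalars being compared from
// input k: for inputs (keys, values), the comparator receives
// (keys[a], keys[b], values[a], values[b]); returning keys[a] < keys[b] sorts
// both tensors by ascending key.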
REGISTER_OP("XlaVariadicSort")
    .Input("inputs: T")
    .Input("dimension: int32")
    .Output("outputs: T")
    .Attr("T: list(type) >= 1")
    .Attr("comparator: func")
    .Attr("is_stable: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort .

Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.

inputs: A list of `Tensor` of identical shape but possibly different types.
dimension: The dimension along which to sort. Must be a compile-time constant.
is_stable: Whether to use stable sort.
comparator: A comparator function that takes 2*N scalars and returns a
  boolean. N is the number of sort inputs. If you want to sort in ascending
  order then the comparator should perform a less-than comparison.
outputs: A list of `Tensor` of the same shape and types as `inputs`.
)doc");

// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor. If the tensor is
      a non-boolean scalar, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified by T.
)doc");

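// Illustrative example for XlaDequantize, derived from the output-shape rules
// in its documentation (the shape function below leaves the output shape
// unknown, so this is not enforced here): each uint32 element packs four uint8
// values, so an input of shape [2, 3] yields an output of shape [2, 12] when
// transpose_output is false and [12, 2] when it is true.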
REGISTER_OP("XlaDequantize")
    .Input("input: uint32")
    .Output("output: bfloat16")
    .Attr("min_range: float")
    .Attr("max_range: float")
    .Attr("mode: string")
    .Attr("transpose_output: bool")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Takes the packed uint32 input and unpacks the input to uint8 to do
dequantization on device.

input: Input tensor of type uint32 with shape [d0, ..., dn].
output: Output tensor of type bfloat16. If transpose_output is true,
     output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
     is false, output shape is [d0, ..., dn * 4].
min_range: The minimum scalar value possibly produced for the input.
max_range: The maximum scalar value possibly produced for the input.
mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
transpose_output: Boolean to determine if output is transposed. transpose_output
     is faster when input is large and rank of input is higher than 1.
)doc");

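// Illustrative example for XlaEinsum: equation = "ab,bc->ac" computes the
// matrix product of `a` with shape [m, k] and `b` with shape [k, n], giving a
// product of shape [m, n]. A single-operand equation such as "ab->ba" contains
// no "," and is rejected by the shape function below.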
REGISTER_OP("XlaEinsum")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("equation: string")
    .Attr("T: {complex64, bfloat16, float}")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      string equation;
      TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
      // XlaEinsum supports only two-input einsum equations.
      if (!absl::StrContains(equation, ",")) {
        return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
                                       equation);
      }
      // Use EinsumShape for the rest of the inference now that we know we must
      // have a two-input einsum.
      return shape_inference::EinsumShape(context);
    })
    .Doc(R"doc(
An op that supports a basic einsum with 2 inputs and 1 output.

This op has better TPU performance than tf.einsum because it does not emit
the explicit reshape and transpose operations that tf.einsum does.
)doc");

REGISTER_OP("XlaSpmdFullToShardShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("dim: int = -1")
    .Attr("unspecified_dims: list(int) = []")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      auto input_handle = c->input(0);
      if (!c->RankKnown(input_handle)) {
        return shape_inference::UnknownShape(c);
      }
      string sharding_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
      int32 single_dim;
      TF_RETURN_IF_ERROR(c->GetAttr("dim", &single_dim));
      xla::OpSharding sharding;
      sharding.ParseFromString(sharding_attr);
      if (sharding.type() != xla::OpSharding::OTHER) {
        return shape_inference::UnchangedShape(c);
      }
      std::vector<shape_inference::DimensionHandle> dims;
      for (int64_t i = 0; i < c->Rank(input_handle); ++i) {
        auto dim = c->Value(c->Dim(input_handle, i));
        if (single_dim < 0 || single_dim == i) {
          int64_t partitions_i = sharding.tile_assignment_dimensions(i);
          if (dim != shape_inference::InferenceContext::kUnknownDim &&
              partitions_i != 1) {
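            // Illustrative example: this is a ceiling division, so uneven
            // partitions are padded. A dimension of size 10 split over 4
            // partitions gives a per-shard size of (10 + 3) / 4 = 3, i.e. 12
            // elements in total, 2 of which are padding.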
            dim = (dim + partitions_i - 1) / partitions_i;
          }
        }
        dims.push_back(c->MakeDim(dim));
      }
      c->set_output(0, c->MakeShape(dims));
      return OkStatus();
    })
    .Doc(R"doc(
An op used by the XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually-partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
The conversion can happen partially in subgroups, by specifying the dim
attribute, where only that dim will be converted.
)doc");

REGISTER_OP("XlaSpmdShardToFullShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("full_shape: shape")
    .Attr("dim: int = -1")
    .Attr("unspecified_dims: list(int) = []")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return OkStatus();
    })
    .Doc(R"doc(
An op used by the XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned input
into a full-shaped tensor to be partitioned automatically with the same sharding
used by manual partitioning. The conversion can happen partially in subgroups,
by specifying the dim attribute, where only that dim will be converted.
)doc");

REGISTER_OP("XlaSharding")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("sharding: string = ''")
    .Attr("unspecified_dims: list(int) = []")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
An op which shards the input based on the given sharding attribute. It can
selectively annotate a subset of tensor dimensions by skipping unspecified_dims,
and the sharding annotation should be replicated in those dims.
)doc");

REGISTER_OP("XlaReplicaId")
    .Output("id: int32")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      context->set_output(0, context->MakeShape({}));
      return OkStatus();
    })
    .Doc("Replica ID.");

REGISTER_OP("XlaGather")
    .Input("operand: T")
    .Input("start_indices: Tindices")
    .Input("slice_sizes: Tindices")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps the XLA Gather operator documented at
  https://www.tensorflow.org/xla/operation_semantics#gather

operand: The array we're gathering from.
start_indices: Array containing the starting indices of the slices we gather.
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
slice_sizes: slice_sizes[i] is the size of the slice along dimension i.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

REGISTER_OP("XlaScatter")
    .Input("operand: T")
    .Input("scatter_indices: Tindices")
    .Input("updates: T")
    .Attr("update_computation: func")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Scatter operator documented at
  https://www.tensorflow.org/xla/operation_semantics#scatter.

operand: Array to be scattered into.
scatter_indices: Array containing the starting indices of the slices that must
  be scattered to.
updates: Array containing the values that must be used for scattering.
update_computation: Computation to be used for combining the existing values in
  the input array and the updates during scatter.
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

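// Illustrative example for XlaAllReduce, assuming the usual XLA replica-group
// layout in which each row of group_assignment lists the replica ids of one
// group: with four replicas in CrossReplica mode, group_assignment =
// [[0, 1], [2, 3]] and reduce_op = 'Add' sum the inputs of replicas 0 and 1
// together and, separately, the inputs of replicas 2 and 3.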
REGISTER_OP("XlaAllReduce")
    .Input("input: T")
    .Input("group_assignment: int32")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32}")
    .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean'}")
    .Attr("mode: {'CrossReplica', 'CrossReplicaAndPartition'}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA AllReduce operator
  documented at https://www.tensorflow.org/xla/operation_semantics#allreduce.

input: Array or a non-empty tuple of arrays to reduce across replicas.
group_assignment: Groups between which the reductions are performed.
reduce_op: Reduction computation.
mode: group mode.
  CrossReplica: group_assignment contains replica_id. Each group contains the
    replicas for the current partition.
  CrossReplicaAndPartition: group_assignment contains replica_id. Each group
    contains the replicas for all partitions.
)doc");

REGISTER_OP("XlaReduceScatter")
    .Input("input: T")
    .Input("group_assignment: int32")
    .Input("scatter_dimension: int32")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32}")
    .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean'}")
    .SetShapeFn(shape_inference::ReduceScatterShape)
    .Doc(R"doc(
Wraps the XLA ReduceScatter operator
  documented at https://www.tensorflow.org/xla/operation_semantics#reducescatter.

input: Array or a non-empty tuple of arrays to reduce across replicas.
group_assignment: Groups between which the reductions are performed.
scatter_dimension: Dimension to scatter.
reduce_op: Reduction computation.
)doc");

Status OptimizationBarrierShape(shape_inference::InferenceContext* c) {
  for (int i = 0; i < c->num_inputs(); ++i) {
    c->set_output(i, c->input(i));
  }
  return OkStatus();
}

REGISTER_OP("XlaOptimizationBarrier")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(OptimizationBarrierShape)
    .Doc(R"doc(
Wraps the XLA OptimizationBarrier operator.

Documented at https://www.tensorflow.org/xla/operation_semantics#optimizationbarrier.

input: A Tuple of Arrays of any type.
)doc");

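// Illustrative example for XlaReducePrecision: exponent_bits = 8 and
// mantissa_bits = 7 emulate bfloat16 precision on a float operand, while
// exponent_bits = 5 and mantissa_bits = 10 emulate IEEE half precision. In
// both cases the output keeps the operand's dtype T.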
REGISTER_OP("XlaReducePrecision")
    .Input("operand: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("exponent_bits: int")
    .Attr("mantissa_bits: int")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA ReducePrecision operator
  documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.

operand: array of floating-point type.
exponent_bits: number of exponent bits in lower-precision format
mantissa_bits: number of mantissa bits in lower-precision format
)doc");

REGISTER_OP("XlaCustomCall")
    .Input("args: T")
    .Output("output: dtype")
    .Attr("target_name: string")
    .Attr("backend_config: string")
    .Attr("T: list(type) >= 0")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA CustomCall operator
  documented at https://www.tensorflow.org/xla/operation_semantics#customcall.

args: A list of `Tensor` with possibly different types.
target_name: Name of the function. A call instruction will be emitted which
  targets this symbol name.
backend_config: String, used to encode serialized metadata to the backend.
dtype: Output tensor data type.
shape: Output tensor shape.
)doc");

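// Illustrative example for XlaCallModule's dim_args_spec: if args have shapes
// [8, 4], [] and [3, 16], then dim_args_spec = ["0.0", "2.1"] passes two
// dimension arguments to the module, with values args[0].shape[0] = 8 and
// args[2].shape[1] = 16.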
REGISTER_OP("XlaCallModule")
    .Input("args: Tin")
    .Output("output: Tout")
    .Attr("module: string")
    .Attr("Sout: list(shape) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("Tin: list(type) >= 0")
    .Attr("dim_args_spec: list(string) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // For debugging
      VLOG(3) << "XlaCallModule.shape_inference";
      std::vector<shape_inference::ShapeHandle> args_shapes;
      TF_RETURN_IF_ERROR(c->input("args", &args_shapes));
      for (int i = 0; i < args_shapes.size(); ++i) {
        VLOG(3) << "XlaCallModule.shape_inference args[" << i
                << "] : " << c->DebugString(args_shapes[i]);
      }
      std::vector<PartialTensorShape> shapes_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("Sout", &shapes_attr));
      for (int i = 0; i < shapes_attr.size(); ++i) {
        shape_inference::ShapeHandle s;
        TF_RETURN_IF_ERROR(
            c->MakeShapeFromPartialTensorShape(shapes_attr[i], &s));
        VLOG(3) << "XlaCallModule.shape_inference out[" << i
                << "] : " << c->DebugString(s);
        c->set_output(i, s);
      }
      return OkStatus();
    })
    .Doc(R"doc(
Temporary op for experimenting with jax2tf.

DO NOT USE THIS OP. It has no backwards compatibility guarantees. It is also
very likely to change. This op will be used only in jax2tf under an
experimental flag.

This is an experimental op to allow a smooth evolution of jax2tf towards
emitting and serializing MHLO directly from JAX. At the moment this op
carries a serialized MHLO module; therefore it has no backward-compatibility
guarantees and should not be used for serialization.
Eventually, the op will carry an MHLO object, which will have
backwards-compatibility guarantees.

The serialized module must return a tuple if and only if Sout is an empty
list or a list with more than one element. The lengths of Tout and Sout must
match. This op always returns a tuple of results, even if the module returns
a single result.

The handling of dynamic shapes is work-in-progress. At the moment, the
JAX lowering for dynamic shapes will prepend one dimension parameter to the
serialized module for each dimension whose value must be passed in.
The "args" correspond to the non-dimension arguments. During compilation
we compute the values of the dimension arguments based on the static shapes of
the "args". In order to do this, we encode for each dimension argument a
specification of how to compute its value, as a string, in the form
"<arg_idx>.<axis_idx>".
E.g., the specification "2.1" denotes the value args[2].shape[1].

args: A list of `Tensor` with possibly different types to be passed as arguments
  to the HLO module.
module: A serialized computation, a text representation of mlir.Module.
Tout: List of output tensor data types.
Sout: List of output tensor shapes.
dim_args_spec: the specification for the dimension arguments, one for each
  dimension argument. In the absence of dynamic shapes this list is empty.
)doc");

}  // namespace
}  // namespace tensorflow