/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

// Note: Most of the operators defined in this module are used by the jax2tf
// converter (see go/jax2tf for details) and are used in SavedModel produced
// by jax2tf. Hence, we need to maintain backwards compatibility for these
// operators. Please reach out to the JAX team if you want to make changes.

namespace tensorflow {
namespace {

// Helper shape function for operators that return an output with the same rank
// as their first input.
Status UnchangedRank(shape_inference::InferenceContext* c) {
  if (c->RankKnown(c->input(0))) {
    c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
  } else {
    c->set_output(0, c->input(0));
  }
  return Status::OK();
}

REGISTER_OP("XlaBroadcastHelper")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("broadcast_dims: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Output("lhs_output: T")
    .Output("rhs_output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Helper operator for performing XLA-style broadcasts.

Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.

lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");
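
// Example (illustrative sketch, assuming XLA's standard broadcast_dimensions
// semantics): with lhs of shape [2, 3, 4], rhs of shape [3], and
// broadcast_dims = [1], rhs_output is reshaped to [1, 3, 1] so it can be
// broadcast against lhs element-wise; lhs, already the higher-rank operand,
// is passed through unchanged as lhs_output.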

REGISTER_OP("XlaSelfAdjointEig")
    .Input("a: T")
    .Attr("lower: bool")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Output("w: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).

Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
tensor such that tensor[...,:,:] * v[..., :,i] = w[..., i] * v[...,:,i], for
i=0...N-1.

a: the input tensor.

lower: a boolean specifying whether the calculation is done with the lower
  triangular part or the upper triangular part.

max_iter: maximum number of sweep updates, i.e., passes over the whole lower or
  upper triangular part (depending on the `lower` attribute). Heuristically, it
  has been argued that approximately log(N) sweeps are needed in practice
  (Ref: Golub & van Loan, "Matrix Computations").

epsilon: the tolerance ratio.

w: The eigenvalues in ascending order, each repeated according to its
  multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
  eigenvalue w[..., i].
)doc");
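
// Shape example (illustrative): for an input `a` of shape [B, N, N], `w` has
// shape [B, N] and `v` has shape [B, N, N], with a[b] * v[b, :, i]
// approximately equal to w[b, i] * v[b, :, i] for each i in [0, N).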

REGISTER_OP("XlaSvd")
    .Input("a: T")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Attr("precision_config: string")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the singular value decomposition of a batch of matrices
(Note: Only real inputs are supported).

Computes the singular values and singular vectors of the innermost M-by-N
matrices in tensor such that
tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).

a: the input tensor.

max_iter: maximum number of sweep updates, i.e., passes over the whole matrix.
  Heuristically, it has been argued that approximately log(min(M, N)) sweeps
  are needed in practice (Ref: Golub & van Loan, "Matrix Computations").

epsilon: the tolerance ratio.

precision_config: a serialized xla::PrecisionConfig proto.

s: Singular values. The values are sorted in descending order of magnitude, so
  s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");

REGISTER_OP("XlaConv")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaConvV2")
    .Input("lhs: LhsT")
    .Input("rhs: RhsT")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("LhsT: numbertype")
    .Attr("RhsT: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Attr("preferred_element_type: numbertype")
    .Output("output: preferred_element_type")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
preferred_element_type: The element type of the output tensor.
)doc");

static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
  shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
  if (!c->RankKnown(lhs_shape_handle) || !c->RankKnown(rhs_shape_handle)) {
    return shape_inference::UnknownShape(c);
  }

  string dimension_numbers_string;
  TF_RETURN_IF_ERROR(
      c->GetAttr("dimension_numbers", &dimension_numbers_string));

  xla::DotDimensionNumbers dimension_numbers;
  dimension_numbers.ParseFromString(dimension_numbers_string);

  // Check that the number of contracting dimensions matches.
  if (dimension_numbers.lhs_contracting_dimensions_size() !=
      dimension_numbers.rhs_contracting_dimensions_size())
    return errors::InvalidArgument(
        "Must specify the same number of contracting dimensions for lhs "
        "and rhs. Got: ",
        dimension_numbers.lhs_contracting_dimensions_size(), " and ",
        dimension_numbers.rhs_contracting_dimensions_size());

  // Check that contracting dimension sizes match.
  for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
       ++i) {
    const int64_t lhs_contracting_dimension =
        dimension_numbers.lhs_contracting_dimensions(i);
    const int64_t rhs_contracting_dimension =
        dimension_numbers.rhs_contracting_dimensions(i);
    shape_inference::DimensionHandle unused;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension),
                 c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension),
                 &unused),
        "For contracting dimension ", i, " which is lhs dimension ",
        lhs_contracting_dimension, " and rhs dimension ",
        rhs_contracting_dimension);
  }

  // Check that the number of batch dimensions matches.
  if (dimension_numbers.lhs_batch_dimensions_size() !=
      dimension_numbers.rhs_batch_dimensions_size())
    return errors::InvalidArgument(
        "Must specify the same number of batch dimensions for lhs "
        "and rhs. Got: ",
        dimension_numbers.lhs_batch_dimensions_size(), " and ",
        dimension_numbers.rhs_batch_dimensions_size());

  // The rank of the result is the number of batch dimensions plus the
  // non-contracted, non-batch dimensions of lhs and rhs. When an input tensor
  // is a scalar, its contribution to the rank of the result is 0. Generate
  // the result dimensions in order: batch dimensions first, then the
  // non-contracted and non-batch lhs and rhs dimensions.
  std::vector<shape_inference::DimensionHandle> output_dims;

  // Check that batch dimension sizes match, and add them to output_dims.
  for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
    const int64_t lhs_batch_dimension =
        dimension_numbers.lhs_batch_dimensions(i);
    const int64_t rhs_batch_dimension =
        dimension_numbers.rhs_batch_dimensions(i);
    shape_inference::DimensionHandle out;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension),
                 c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension), &out),
        "For batch dimension ", i, " which is lhs dimension ",
        lhs_batch_dimension, " and rhs dimension ", rhs_batch_dimension);
    output_dims.emplace_back(out);
  }

  const int32_t lhs_rank = c->Rank(lhs_shape_handle);
  for (int64_t i = 0; i < lhs_rank; ++i) {
    if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
                              i) ||
        absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
      continue;
    }
    output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
  }

  const int32_t rhs_rank = c->Rank(rhs_shape_handle);
  for (int64_t i = 0; i < rhs_rank; ++i) {
    if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
                              i) ||
        absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
      continue;
    }
    output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
  }

  c->set_output(0, c->MakeShape(output_dims));
  return Status::OK();
}

REGISTER_OP("XlaDot")
    .Input("lhs: T")
    .Input("rhs: T")
    .Attr("T: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(XlaDotShapeFunction)
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
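
// Example of the shape inference implemented in XlaDotShapeFunction
// (illustrative): with lhs of shape [B, M, K], rhs of shape [B, K, N],
// lhs_batch_dimensions = rhs_batch_dimensions = [0],
// lhs_contracting_dimensions = [2] and rhs_contracting_dimensions = [1], the
// inferred output shape is [B, M, N]: batch dimensions first, then the
// remaining non-contracted lhs dimensions, then those of rhs.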

REGISTER_OP("XlaDotV2")
    .Input("lhs: LhsT")
    .Input("rhs: RhsT")
    .Attr("LhsT: numbertype")
    .Attr("RhsT: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Attr("preferred_element_type: numbertype")
    .Output("output: preferred_element_type")
    .SetShapeFn(XlaDotShapeFunction)
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
preferred_element_type: The element type of the output tensor.
)doc");

REGISTER_OP("XlaSetBound")
    .Input("input: int32")
    .Input("bound: int32")
    .Output("output: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Set a bound for the given input value as a hint to the XLA
        compiler; returns the same value.
)doc");

REGISTER_OP("XlaSetDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Input("size: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Make a static dimension into an XLA bounded dynamic dimension.
        The current static dimension size will become the bound and the second
        operand becomes the dynamic size of the dimension.)doc");

REGISTER_OP("XlaRemoveDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Inverse of XlaSetDynamicDimensionSize.

Make an XLA bounded dynamic dimension into a static dimension. The bound of the
size of dimension `dim_index` becomes the static dimension size.
)doc");

REGISTER_OP("XlaDynamicSlice")
    .Input("input: T")
    .Input("start_indices: Tindices")
    .Input("size_indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      shape_inference::ShapeHandle size_indices_shape = c->input(2);
      if (!c->RankKnown(size_indices_shape)) {
        return UnchangedRank(c);
      }
      if (c->Rank(size_indices_shape) != 1) {
        return errors::InvalidArgument("size_indices must be a 1D tensor");
      }
      shape_inference::ShapeHandle size_indices_value;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &size_indices_value));
      if (!c->RankKnown(size_indices_value)) {
        // If we cannot tell the rank of the output from the value of
        // size_indices, perhaps we can find it from the rank of first operand.
        return UnchangedRank(c);
      }
      c->set_output(0, size_indices_value);
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA DynamicSlice operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
.

DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
dimension -- [start, start + size). The shape of start_indices must have rank 1,
with dimension size equal to the rank of operand.

input: A `Tensor` of type T.

start_indices: Rank 1 tensor of N integers containing the starting indices of
  the slice for each dimension. Each value must be greater than or equal to
  zero.

size_indices: List of N integers containing the slice size for each
  dimension. Each value must be strictly greater than zero, and start + size
  must be less than or equal to the size of the dimension to avoid
  implementation defined behavior.
)doc");
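
// Example (illustrative): with input of shape [4, 5], start_indices = [1, 2]
// and size_indices = [2, 3], the output is input[1:3, 2:5] and has shape
// [2, 3].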

REGISTER_OP("XlaDynamicUpdateSlice")
    .Input("input: T")
    .Input("update: T")
    .Input("indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.

XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. The shape
of `indices` must be rank 1, with dimension size equal to the rank of `input`.

Handling of out-of-bounds slice indices is implementation-defined.

input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
  `input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");

// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
    .Input("cond: Tcond")
    .Input("inputs: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).

cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
        else_branch(inputs). The input shapes of the then_branch and
        else_branch must match.
then_branch: A function that takes 'inputs' and returns a list of tensors,
             whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of tensors,
             whose types are the same as what then_branch returns.
)doc");

REGISTER_OP("XlaPad")
    .Input("input: T")
    .Input("padding_value: T")
    .Input("padding_low: Tindices")
    .Input("padding_high: Tindices")
    .Input("padding_interior: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle input_shape_handle = c->input(0);
      if (!c->RankKnown(input_shape_handle)) {
        return UnchangedRank(c);
      }
      const int32_t op_rank = c->Rank(input_shape_handle);

      shape_inference::ShapeHandle padding_shape_handle = c->input(1);
      if (c->RankKnown(padding_shape_handle) &&
          c->Rank(padding_shape_handle) != 0) {
        return errors::InvalidArgument(
            "padding_value input must be scalar, found rank ",
            c->Rank(padding_shape_handle));
      }
      const Tensor* padding_low_tensor = c->input_tensor(2);
      const Tensor* padding_high_tensor = c->input_tensor(3);
      const Tensor* padding_interior_tensor = c->input_tensor(4);
      if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
          padding_interior_tensor == nullptr) {
        return UnchangedRank(c);
      }

      if (padding_low_tensor->shape().dims() != 1 ||
          padding_low_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_low must be a 1D tensor of size ", op_rank);
      }
      if (padding_high_tensor->shape().dims() != 1 ||
          padding_high_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_high must be a 1D tensor of size ", op_rank);
      }
      if (padding_interior_tensor->shape().dims() != 1 ||
          padding_interior_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_interior must be a 1D tensor of size ", op_rank);
      }
      std::vector<shape_inference::DimensionHandle> output_dims;
      output_dims.reserve(op_rank);
      for (int64_t i = 0; i < op_rank; ++i) {
        int64_t low, high, interior;
        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_high_tensor, i, &high));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
        if (interior < 0) {
          return errors::InvalidArgument(
              "padding_interior must contain only non-negative values, found ",
              interior);
        }

        shape_inference::DimensionHandle orig_size_handle =
            c->Dim(input_shape_handle, i);
        if (c->ValueKnown(orig_size_handle)) {
          auto orig_dim = c->Value(orig_size_handle);
          int64_t new_dim = orig_dim + low + high;
          if (orig_dim > 0) {
            new_dim += interior * (orig_dim - 1);
          }
          if (new_dim < 0) {
            return errors::InvalidArgument(
                "resulting padded dimension has negative size ", new_dim);
          }
          output_dims.emplace_back(c->MakeDim(new_dim));
        } else {
          output_dims.emplace_back(c->UnknownDim());
        }
      }

      c->set_output(0, c->MakeShape(output_dims));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Pad operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#pad
.

input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimension. Must
  be a compile-time constant 1D tensor of length equal to rank of input.
padding_high: the padding to apply at the end of each input dimension. Must
  be a compile-time constant 1D tensor of length equal to rank of input.
padding_interior: the padding to apply between each input element. Must
  be a compile-time constant 1D tensor of length equal to rank of input,
  containing only non-negative values.
output: A `Tensor` of type T.
)doc");
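
// Worked example of the padded-size arithmetic in the shape function above
// (illustrative): for a dimension of original size 4 with low = 1, high = 2
// and interior = 1, the padded size is 4 + 1 + 2 + interior * (4 - 1) = 10,
// i.e. one interior element is inserted between each pair of adjacent input
// elements before the edge padding is applied.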

REGISTER_OP("XlaRecv")
    .Output("tensor: dtype")
    .Attr("dtype: type")
    .Attr("tensor_name: string")
    .Attr("shape: shape")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#recv .

tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");

REGISTER_OP("XlaReduce")
    .Input("input: T")
    .Input("init_value: T")
    .Attr("T: {numbertype, bool}")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64> dims_set(dimensions_to_reduce.begin(),
                                 dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaReduce");
        }
        c->set_output(
            0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
      } else {
        c->set_output(0, c->input(0));
      }
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Reduce operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#reduce .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
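
// Example (illustrative): reducing a rank-3 input with dimensions_to_reduce =
// [0, 2] yields a rank-1 output; the shape function above only infers the
// output rank (rank - dimensions_to_reduce.size()), not the exact dimension
// sizes.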

REGISTER_OP("XlaVariadicReduce")
    .Input("input: N * T")
    .Input("init_value: N * T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, bool}")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: N * T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          c->MergeInput(i, c->input(j));
        }
      }
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64> dims_set(dimensions_to_reduce.begin(),
                                 dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
        }
        for (int i = 0; i < n; i++) {
          c->set_output(
              i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
        }
      } else {
        for (int i = 0; i < n; i++) {
          c->set_output(i, c->input(i));
        }
      }
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

This version is limited to operands of the same dtype.
XlaVariadicReduceV2 is a version that supports heterogeneous operands.

input: the input tensor(s)
init_value: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaVariadicReduceV2")
    .Input("inputs: T")
    .Input("init_values: T")
    .Attr("T: list(type) >= 1")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("outputs: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      std::vector<shape_inference::ShapeHandle> init_values_shapes;
      TF_RETURN_IF_ERROR(c->input("init_values", &init_values_shapes));
      const int nr_inputs = input_shapes.size();
      if (nr_inputs != init_values_shapes.size()) {
        return errors::InvalidArgument(
            "Must specify the same number of inputs and init_values. ", "Got ",
            nr_inputs, " and ", init_values_shapes.size());
      }
      if (nr_inputs == 0) {
        return errors::InvalidArgument("Must specify at least one input");
      }

      shape_inference::ShapeHandle input_shape = input_shapes[0];
      for (int i = 1; i < nr_inputs; ++i) {
        shape_inference::ShapeHandle merged;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->Merge(input_shape, input_shapes[i], &merged),
            "All inputs must have the same shape. Input ", i,
            " (zero-based) has shape ", c->DebugString(input_shapes[i]),
            " incompatible with the shape ", "inferred from previous inputs ",
            c->DebugString(input_shape));
        input_shape = merged;
      }
      // All outputs have the same shape
      shape_inference::ShapeHandle output_shape = c->UnknownShape();

      if (c->RankKnown(input_shape)) {
        int rank = c->Rank(input_shape);

        std::vector<int64> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64> dims_set(dimensions_to_reduce.begin(),
                                 dimensions_to_reduce.end());

        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2");
        }

        std::vector<shape_inference::DimensionHandle> output_dims;
        for (int64_t i = 0; i < rank; ++i) {
          if (dims_set.find(i) == dims_set.end()) {
            output_dims.emplace_back(c->Dim(input_shape, i));
          }
        }
        output_shape = c->MakeShape(output_dims);
      }
      for (int i = 0; i < nr_inputs; ++i) {
        c->set_output(i, output_shape);
      }
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

This is an expanded version of XlaVariadicReduce, with support for
operands of different dtypes, and improved shape inference.

inputs: the input tensor(s)
init_values: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaReduceWindow")
    .Input("input: T")
    .Input("init_value: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("base_dilations: Tindices")
    .Input("window_dilations: Tindices")
    .Input("padding: Tindices")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Attr("computation: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
)doc");

REGISTER_OP("XlaRngBitGenerator")
    .Input("algorithm: int32")
    .Input("initial_state: uint64")
    .Input("shape: Tshape")
    .Output("output_key: uint64")
    .Output("output: dtype")
    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle algorithm;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm));
      shape_inference::ShapeHandle initial_state;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state));

      c->set_output(0, initial_state);
      shape_inference::ShapeHandle output;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output));
      c->set_output(1, output);
      return Status::OK();
    })
    .Doc(R"doc(
Stateless PRNG bit generator.
Wraps the XLA RngBitGenerator operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

algorithm: The PRNG algorithm to use, one of
  tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be
  a u64[2] and for PHILOX a u64[3].
shape: The output shape of the generated data.
dtype: The type of the tensor.
)doc");
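
// Shape example (illustrative): with algorithm = THREEFRY, initial_state of
// shape [2] (u64[2]) and shape = [16], output_key has the same shape as
// initial_state ([2]) and output has shape [16] with element type `dtype`,
// matching the shape function above.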

REGISTER_OP("XlaSelectAndScatter")
    .Input("operand: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("source: T")
    .Input("init_value: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("select: func")
    .Attr("scatter: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA SelectAndScatter operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
.

operand: the input tensor
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
source: a tensor of values to scatter
init_value: a scalar representing the initial value for the output tensor
select: a selection function to apply
scatter: a scatter function to apply
)doc");

REGISTER_OP("XlaSend")
    .Input("tensor: T")
    .Attr("T: type")
    .Attr("tensor_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#send .

tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");

REGISTER_OP("XlaSort")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorting in ascending order is supported.

input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");

REGISTER_OP("XlaKeyValueSort")
    .Input("keys: K")
    .Input("values: V")
    .Output("sorted_keys: K")
    .Output("sorted_values: V")
    .Attr("K: realnumbertype")
    .Attr("V: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorting in ascending order is supported.

keys: A `Tensor` of type K.
values: A `Tensor` of type V.
sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");

REGISTER_OP("XlaVariadicSort")
    .Input("inputs: T")
    .Input("dimension: int32")
    .Output("outputs: T")
    .Attr("T: list(type) >= 1")
    .Attr("comparator: func")
    .Attr("is_stable: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.

inputs: A list of `Tensor` of identical shape but possibly different types.
dimension: The dimension along which to sort. Must be a compile-time constant.
is_stable: Whether to use stable sort.
comparator: A comparator function that takes 2*N scalars and returns a
  boolean. N is the number of sort inputs. If you want to sort in ascending
  order then the comparator should perform a less-than comparison.
outputs: A list of `Tensor` with the same shapes and types as `inputs`.
)doc");
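
// Comparator example (an informal sketch of the XLA convention, not
// normative): when sorting a keys tensor together with a values tensor
// (N = 2), the comparator receives the scalars (lhs_keys, rhs_keys,
// lhs_values, rhs_values); returning lhs_keys < rhs_keys sorts both tensors
// by key in ascending order.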

// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor.  If the tensor is
      a non-boolean scalar, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, it is True if and only if it is non-empty.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified by T.
)doc");

REGISTER_OP("XlaDequantize")
    .Input("input: uint32")
    .Output("output: bfloat16")
    .Attr("min_range: float")
    .Attr("max_range: float")
    .Attr("mode: string")
    .Attr("transpose_output: bool")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Takes the packed uint32 input and unpacks the input to uint8 to do
dequantization on device.

input: Input tensor whose type is uint32 and shape is [d0, ..., dn].
output: Output tensor whose type is bfloat16. If transpose_output is true,
     output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
     is false, output shape is [d0, ..., dn * 4].
min_range: The minimum scalar value possibly produced for the input.
max_range: The maximum scalar value possibly produced for the input.
mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
transpose_output: Boolean to determine if output is transposed. transpose_output
     is faster when input is large and rank of input is higher than 1.
)doc");

REGISTER_OP("XlaEinsum")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("equation: string")
    .Attr("T: {complex64, bfloat16, float}")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      string equation;
      TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
      // XlaEinsum supports only two-input einsum equations.
      if (!absl::StrContains(equation, ",")) {
        return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
                                       equation);
      }
      // Use EinsumShape for the rest of the inference now that we know we must
      // have a two-input einsum.
      return shape_inference::EinsumShape(context);
    })
    .Doc(R"doc(
An op which supports a basic einsum operation with 2 inputs and 1 output.

This op has better TPU performance since it doesn't have the explicit reshape
and transpose operations that tf.einsum does.
)doc");

REGISTER_OP("XlaSpmdFullToShardShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      auto input_handle = c->input(0);
      if (!c->RankKnown(input_handle)) {
        return shape_inference::UnknownShape(c);
      }
      string sharding_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
      xla::OpSharding sharding;
      sharding.ParseFromString(sharding_attr);
      if (sharding.type() != xla::OpSharding::OTHER) {
        return shape_inference::UnchangedShape(c);
      }
      std::vector<shape_inference::DimensionHandle> dims;
      for (int64_t i = 0; i < c->Rank(input_handle); ++i) {
        auto dim = c->Value(c->Dim(input_handle, i));
        int64_t partitions_i = sharding.tile_assignment_dimensions(i);
        if (dim != shape_inference::InferenceContext::kUnknownDim &&
            partitions_i != 1) {
          dim = (dim + partitions_i - 1) / partitions_i;
        }
        dims.push_back(c->MakeDim(dim));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually-partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
)doc");
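
// Worked example of the shard-shape computation above (illustrative): for a
// full shape of [10, 8] and tile_assignment_dimensions = [4, 1], the
// per-shard shape is [ceil(10 / 4), 8] = [3, 8]; dimensions with a partition
// count of 1, and dimensions of unknown size, are left unchanged.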

REGISTER_OP("XlaSpmdShardToFullShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("full_shape: shape")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned input
into a full-shaped tensor to be partitioned automatically with the same sharding
used by manual partitioning.
)doc");

REGISTER_OP("XlaSharding")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("sharding: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
An op which shards the input based on the given sharding attribute.
)doc");

REGISTER_OP("XlaReplicaId")
    .Output("id: int32")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      context->set_output(0, context->MakeShape({}));
      return Status::OK();
    })
    .Doc("Replica ID.");

REGISTER_OP("XlaGather")
    .Input("operand: T")
    .Input("start_indices: Tindices")
    .Input("slice_sizes: Tindices")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps the XLA Gather operator documented at
  https://www.tensorflow.org/xla/operation_semantics#gather

operand: The array we're gathering from.
start_indices: Array containing the starting indices of the slices we gather.
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

REGISTER_OP("XlaScatter")
    .Input("operand: T")
    .Input("scatter_indices: Tindices")
    .Input("updates: T")
    .Attr("update_computation: func")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Scatter operator documented at
  https://www.tensorflow.org/xla/operation_semantics#scatter.

operand: Array to be scattered into.
scatter_indices: Array containing the starting indices of the slices that must
  be scattered to.
updates: Array containing the values that must be used for scattering.
update_computation: Computation to be used for combining the existing values in
  the input array and the updates during scatter.
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

}  // namespace
}  // namespace tensorflow