1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cstddef>
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
27
28 // Note: Most of the operators defined in this module are used by the jax2tf
29 // converter (see go/jax2tf for details) and are used in SavedModel produced
30 // by jax2tf. Hence, we need to maintain backwards compatibility for these
31 // operators. Please reach out to the JAX team if you want to make changes.
32
33 namespace tensorflow {
34 namespace {
35
36 // Helper shape function for operators that return an output with the same rank
37 // as their first input.
UnchangedRank(shape_inference::InferenceContext * c)38 Status UnchangedRank(shape_inference::InferenceContext* c) {
39 if (c->RankKnown(c->input(0))) {
40 c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
41 } else {
42 c->set_output(0, c->input(0));
43 }
44 return Status::OK();
45 }
46
47 REGISTER_OP("XlaBroadcastHelper")
48 .Input("lhs: T")
49 .Input("rhs: T")
50 .Input("broadcast_dims: Tindices")
51 .Attr("T: numbertype")
52 .Attr("Tindices: {int32, int64}")
53 .Output("lhs_output: T")
54 .Output("rhs_output: T")
55 .SetShapeFn(shape_inference::UnknownShape)
56 .Doc(R"doc(
57 Helper operator for performing XLA-style broadcasts
58
59 Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
60 whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
61 for binary operators.
62
63 lhs: the LHS input tensor
64 rhs: the RHS input tensor
65 broadcast_dims: an XLA-style broadcast dimension specification
66 lhs_output: the broadcasted LHS tensor
67 rhs_output: the broadcasted RHS tensor
68 )doc");
69
70 REGISTER_OP("XlaSelfAdjointEig")
71 .Input("a: T")
72 .Attr("lower: bool")
73 .Attr("max_iter: int")
74 .Attr("epsilon: float")
75 .Output("w: T")
76 .Output("v: T")
77 .SetShapeFn(shape_inference::UnknownShape)
78 .Attr("T: numbertype")
79 .Doc(R"doc(
80 Computes the eigen decomposition of a batch of self-adjoint matrices
81 (Note: Only real inputs are supported).
82
83 Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
84 tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
85 i=0...N-1.
86
87 a: the input tensor.
88
89 lower: a boolean specifies whether the calculation is done with the lower
90 triangular part or the upper triangular part.
91
92 max_iter: maximum number of sweep update, i.e., the whole lower triangular
93 part or upper triangular part based on parameter lower. Heuristically, it has
94 been argued that approximately logN sweeps are needed in practice (Ref: Golub &
95 van Loan "Matrix Computation").
96
97 epsilon: the tolerance ratio.
98
99 w: The eigenvalues in ascending order, each repeated according to its
100 multiplicity.
101 v: The column v[..., :, i] is the normalized eigenvector corresponding to the
102 eigenvalue w[..., i].
103 )doc");
104
105 REGISTER_OP("XlaSvd")
106 .Input("a: T")
107 .Attr("max_iter: int")
108 .Attr("epsilon: float")
109 .Attr("precision_config: string")
110 .Output("s: T")
111 .Output("u: T")
112 .Output("v: T")
113 .SetShapeFn(shape_inference::UnknownShape)
114 .Attr("T: numbertype")
115 .Doc(R"doc(
116 Computes the eigen decomposition of a batch of self-adjoint matrices
117 (Note: Only real inputs are supported).
118
119 Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
120 tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
121
122 a: the input tensor.
123
124 max_iter: maximum number of sweep update, i.e., the whole lower triangular
125 part or upper triangular part based on parameter lower. Heuristically, it has
126 been argued that approximately log(min (M, N)) sweeps are needed in practice
127 (Ref: Golub & van Loan "Matrix Computation").
128
129 epsilon: the tolerance ratio.
130
131 precision_config: a serialized xla::PrecisionConfig proto.
132
133 s: Singular values. The values are sorted in reverse order of magnitude, so
134 s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
135 u: Left singular vectors.
136 v: Right singular vectors.
137 )doc");
138
139 REGISTER_OP("XlaConv")
140 .Input("lhs: T")
141 .Input("rhs: T")
142 .Input("window_strides: Tindices")
143 .Input("padding: Tindices")
144 .Input("lhs_dilation: Tindices")
145 .Input("rhs_dilation: Tindices")
146 .Input("feature_group_count: Tindices")
147 .Attr("T: numbertype")
148 .Attr("Tindices: {int32, int64}")
149 .Attr("dimension_numbers: string")
150 .Attr("precision_config: string")
151 .Output("output: T")
152 .SetShapeFn(UnchangedRank)
153 .Doc(R"doc(
154 Wraps the XLA ConvGeneralDilated operator, documented at
155 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
156 .
157
158 lhs: the input tensor
159 rhs: the kernel tensor
160 window_strides: the inter-window strides
161 padding: the padding to apply at the start and end of each input dimensions
162 lhs_dilation: dilation to apply between input elements
163 rhs_dilation: dilation to apply between kernel elements
164 feature_group_count: number of feature groups for grouped convolution.
165 dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
166 precision_config: a serialized xla::PrecisionConfig proto.
167 )doc");
168
169 REGISTER_OP("XlaConvV2")
170 .Input("lhs: LhsT")
171 .Input("rhs: RhsT")
172 .Input("window_strides: Tindices")
173 .Input("padding: Tindices")
174 .Input("lhs_dilation: Tindices")
175 .Input("rhs_dilation: Tindices")
176 .Input("feature_group_count: Tindices")
177 .Attr("LhsT: numbertype")
178 .Attr("RhsT: numbertype")
179 .Attr("Tindices: {int32, int64}")
180 .Attr("dimension_numbers: string")
181 .Attr("precision_config: string")
182 .Attr("preferred_element_type: numbertype")
183 .Output("output: preferred_element_type")
184 .SetShapeFn(UnchangedRank)
185 .Doc(R"doc(
186 Wraps the XLA ConvGeneralDilated operator, documented at
187 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
188 .
189
190 lhs: the input tensor
191 rhs: the kernel tensor
192 window_strides: the inter-window strides
193 padding: the padding to apply at the start and end of each input dimensions
194 lhs_dilation: dilation to apply between input elements
195 rhs_dilation: dilation to apply between kernel elements
196 feature_group_count: number of feature groups for grouped convolution.
197 dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
198 precision_config: a serialized xla::PrecisionConfig proto.
199 preferred_element_type: The type of the tensor.
200 )doc");
201
XlaDotShapeFunction(shape_inference::InferenceContext * c)202 static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
203 shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
204 shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
205 if (!c->RankKnown(lhs_shape_handle) || !c->RankKnown(rhs_shape_handle)) {
206 return shape_inference::UnknownShape(c);
207 }
208
209 string dimension_numbers_string;
210 TF_RETURN_IF_ERROR(
211 c->GetAttr("dimension_numbers", &dimension_numbers_string));
212
213 xla::DotDimensionNumbers dimension_numbers;
214 dimension_numbers.ParseFromString(dimension_numbers_string);
215
216 // Check that number of contracting dimensions match.
217 if (dimension_numbers.lhs_contracting_dimensions_size() !=
218 dimension_numbers.rhs_contracting_dimensions_size())
219 return errors::InvalidArgument(
220 "Must specify the same number of contracting dimensions for lhs "
221 "and rhs. Got: ",
222 dimension_numbers.lhs_contracting_dimensions_size(), " and ",
223 dimension_numbers.rhs_contracting_dimensions_size());
224
225 // Check that contracting dimension sizes match.
226 for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
227 ++i) {
228 const int64_t lhs_contracting_dimension =
229 dimension_numbers.lhs_contracting_dimensions(i);
230 const int64_t rhs_contracting_dimension =
231 dimension_numbers.rhs_contracting_dimensions(i);
232 shape_inference::DimensionHandle unused;
233 TF_RETURN_WITH_CONTEXT_IF_ERROR(
234 c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension),
235 c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension),
236 &unused),
237 "For contracting dimension ", i, " which is lhs dimension ",
238 lhs_contracting_dimension, " and rhs dimension ",
239 rhs_contracting_dimension);
240 }
241
242 // Check that number of batch dimensions match.
243 if (dimension_numbers.lhs_batch_dimensions_size() !=
244 dimension_numbers.rhs_batch_dimensions_size())
245 return errors::InvalidArgument(
246 "Must specify the same number of batch dimensions for lhs "
247 "and rhs. Got: ",
248 dimension_numbers.lhs_batch_dimensions_size(), " and ",
249 dimension_numbers.rhs_batch_dimensions_size());
250
251 // The ranks of lhs and rhs are decremented by the number of contractions,
252 // and added for the rank of the result. When an input tensor
253 // is a scalar, its contribution to the rank of the result is 0. Generate
254 // the result dimensions in order, batch dimensions, then the
255 // non-contracted and non-batch lhs and rhs dimensions.
256 std::vector<shape_inference::DimensionHandle> output_dims;
257
258 // Check that batch dimension sizes match, and add them to output_dims.
259 for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
260 const int64_t lhs_batch_dimension =
261 dimension_numbers.lhs_batch_dimensions(i);
262 const int64_t rhs_batch_dimension =
263 dimension_numbers.rhs_batch_dimensions(i);
264 shape_inference::DimensionHandle out;
265 TF_RETURN_WITH_CONTEXT_IF_ERROR(
266 c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension),
267 c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension), &out),
268 "For batch dimension ", i, " which is lhs dimension ",
269 lhs_batch_dimension, " and rhs dimension ", rhs_batch_dimension);
270 output_dims.emplace_back(out);
271 }
272
273 const int32_t lhs_rank = c->Rank(lhs_shape_handle);
274 for (int64_t i = 0; i < lhs_rank; ++i) {
275 if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
276 i) ||
277 absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
278 continue;
279 }
280 output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
281 }
282
283 const int32_t rhs_rank = c->Rank(rhs_shape_handle);
284 for (int64_t i = 0; i < rhs_rank; ++i) {
285 if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
286 i) ||
287 absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
288 continue;
289 }
290 output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
291 }
292
293 c->set_output(0, c->MakeShape(output_dims));
294 return Status::OK();
295 }
296
297 REGISTER_OP("XlaDot")
298 .Input("lhs: T")
299 .Input("rhs: T")
300 .Attr("T: numbertype")
301 .Attr("dimension_numbers: string")
302 .Attr("precision_config: string")
303 .Output("output: T")
304 .SetShapeFn(XlaDotShapeFunction)
305 .Doc(R"doc(
306 Wraps the XLA DotGeneral operator, documented at
307 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
308 .
309
310 lhs: the LHS tensor
311 rhs: the RHS tensor
312 dimension_numbers: a serialized xla::DotDimensionNumbers proto.
313 precision_config: a serialized xla::PrecisionConfig proto.
314 )doc");
315
316 REGISTER_OP("XlaDotV2")
317 .Input("lhs: LhsT")
318 .Input("rhs: RhsT")
319 .Attr("LhsT: numbertype")
320 .Attr("RhsT: numbertype")
321 .Attr("dimension_numbers: string")
322 .Attr("precision_config: string")
323 .Attr("preferred_element_type: numbertype")
324 .Output("output: preferred_element_type")
325 .SetShapeFn(XlaDotShapeFunction)
326 .Doc(R"doc(
327 Wraps the XLA DotGeneral operator, documented at
328 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
329 .
330
331 lhs: the LHS tensor
332 rhs: the RHS tensor
333 dimension_numbers: a serialized xla::DotDimensionNumbers proto.
334 precision_config: a serialized xla::PrecisionConfig proto.
335 preferred_element_type: The type of the tensor.
336 )doc");
337
338 REGISTER_OP("XlaSetBound")
339 .Input("input: int32")
340 .Input("bound: int32")
341 .Output("output: int32")
342 .SetShapeFn(shape_inference::UnknownShape)
343 .Doc(
344 R"doc(Set a bound for the given input value as a hint to Xla compiler,
345 returns the same value.
346 )doc");
347
348 REGISTER_OP("XlaSetDynamicDimensionSize")
349 .Input("input: T")
350 .Input("dim_index: int32")
351 .Input("size: int32")
352 .Output("output: T")
353 .Attr("T: type")
354 // Use unknown shape to prevent constant folding.
355 .SetShapeFn(shape_inference::UnknownShape)
356 .Doc(
357 R"doc(Make a static dimension into a xla bounded dynamic dimension.
358 The current static dimension size will become the bound and the second
359 operand becomes the dynamic size of the dimension.)doc");
360
361 REGISTER_OP("XlaRemoveDynamicDimensionSize")
362 .Input("input: T")
363 .Input("dim_index: int32")
364 .Output("output: T")
365 .Attr("T: type")
366 // Use unknown shape to prevent constant folding.
367 .SetShapeFn(shape_inference::UnknownShape)
368 .Doc(R"doc(
369 Inverse of XlaSetDynamicDimensionSize.
370
371 Make an xla bounded dynamic dimension into a static dimension. The bound of the
372 size of dimension `dim_index` becomes the static dimension size.
373 )doc");
374
375 REGISTER_OP("XlaDynamicSlice")
376 .Input("input: T")
377 .Input("start_indices: Tindices")
378 .Input("size_indices: Tindices")
379 .Output("output: T")
380 .Attr("T: type")
381 .Attr("Tindices: {int32, int64}")
__anon592c2f360202(shape_inference::InferenceContext* c) 382 .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
383 shape_inference::ShapeHandle size_indices_shape = c->input(2);
384 if (!c->RankKnown(size_indices_shape)) {
385 return UnchangedRank(c);
386 }
387 if (c->Rank(size_indices_shape) != 1) {
388 return errors::InvalidArgument("size_indices must be a 1D tensor");
389 }
390 shape_inference::ShapeHandle size_indices_value;
391 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &size_indices_value));
392 if (!c->RankKnown(size_indices_value)) {
393 // If we cannot tell the rank of the output from the value of
394 // size_indices, perhaps we can find it from the rank of first operand.
395 return UnchangedRank(c);
396 }
397 c->set_output(0, size_indices_value);
398 return Status::OK();
399 })
400 .Doc(R"doc(
401 Wraps the XLA DynamicSlice operator, documented at
402 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
403 .
404
405 DynamicSlice extracts a sub-array from the input array at dynamic
406 start_indices. The size of the slice in each dimension is passed in
407 size_indices, which specify the end point of exclusive slice intervals in each
408 dimension -- [start, start + size). The shape of start_indices must have rank 1,
409 with dimension size equal to the rank of operand.
410
411 input: A `Tensor` of type T.
412
413 start_indices: Rank 1 tensor of N integers containing the starting indices of
414 the slice for each dimension. Value must be greater than or equal to zero.
415
416 start_indices: List of N integers containing the slice size for each
417 dimension. Each value must be strictly greater than zero, and start + size
418 must be less than or equal to the size of the dimension to avoid
419 implementation defined behavior.
420 )doc");
421
422 REGISTER_OP("XlaDynamicUpdateSlice")
423 .Input("input: T")
424 .Input("update: T")
425 .Input("indices: Tindices")
426 .Output("output: T")
427 .Attr("T: type")
428 .Attr("Tindices: {int32, int64}")
429 .SetShapeFn(shape_inference::UnchangedShape)
430 .Doc(R"doc(
431 Wraps the XLA DynamicUpdateSlice operator, documented at
432 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
433 .
434
435 XlaDynamicUpdateSlice generates a result which is the value of the `input`
436 operand, with a slice update overwritten at `indices`. The shape of `update`
437 determines the shape of the sub-array of the result which is updated. The shape
438 of indices must be rank == 1, with dimension size equal to the rank of `input`.
439
440 Handling of out-of-bounds slice indices is implementation-defined.
441
442 input: A `Tensor` of type T.
443 indices: A vector of indices into `input`. Must have length equal to the rank of
444 `input`.
445 update: A `Tensor` of type T. Same rank as `input`.
446 output: A `Tensor` of type T.
447 )doc");
448
449 // TODO(b/37549631) setting the If Op to always be stateful is too
450 // conservative.
451 REGISTER_OP("XlaIf")
452 .Input("cond: Tcond")
453 .Input("inputs: Tin")
454 .Output("output: Tout")
455 .Attr("Tcond: type")
456 .Attr("then_branch: func")
457 .Attr("else_branch: func")
458 .Attr("Tin: list(type) >= 0")
459 .Attr("Tout: list(type) >= 0")
460 .SetIsStateful()
461 .SetShapeFn(shape_inference::UnknownShape)
462 .Doc(R"doc(
463 output = cond ? then_branch(inputs) : else_branch(inputs).
464
465 cond: A boolean scalar.
466 inputs: A list of input tensors.
467 output: A list of tensors returned by either then_branch(inputs) or
468 else_branch(inputs). The input shapes of the then_branch and
469 else_branch must match.
470 then_branch: A function takes 'inputs' and returns a list of tensors,
471 whose types are the same as what else_branch returns.
472 else_branch: A function takes 'inputs' and returns a list of tensors.
473 whose types are the same as what then_branch returns.
474 )doc");
475
476 REGISTER_OP("XlaPad")
477 .Input("input: T")
478 .Input("padding_value: T")
479 .Input("padding_low: Tindices")
480 .Input("padding_high: Tindices")
481 .Input("padding_interior: Tindices")
482 .Output("output: T")
483 .Attr("T: type")
484 .Attr("Tindices: {int32, int64}")
__anon592c2f360302(shape_inference::InferenceContext* c) 485 .SetShapeFn([](shape_inference::InferenceContext* c) {
486 shape_inference::ShapeHandle input_shape_handle = c->input(0);
487 if (!c->RankKnown(input_shape_handle)) {
488 return UnchangedRank(c);
489 }
490 const int32_t op_rank = c->Rank(input_shape_handle);
491
492 shape_inference::ShapeHandle padding_shape_handle = c->input(1);
493 if (c->RankKnown(padding_shape_handle) &&
494 c->Rank(padding_shape_handle) != 0) {
495 return errors::InvalidArgument(
496 "padding_value input must be scalar, found rank ",
497 c->Rank(padding_shape_handle));
498 }
499 const Tensor* padding_low_tensor = c->input_tensor(2);
500 const Tensor* padding_high_tensor = c->input_tensor(3);
501 const Tensor* padding_interior_tensor = c->input_tensor(4);
502 if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
503 padding_interior_tensor == nullptr) {
504 return UnchangedRank(c);
505 }
506
507 if (padding_low_tensor->shape().dims() != 1 ||
508 padding_low_tensor->shape().dim_size(0) != op_rank) {
509 return errors::InvalidArgument(
510 "padding_low must be a 1D tensor of size ", op_rank);
511 }
512 if (padding_high_tensor->shape().dims() != 1 ||
513 padding_high_tensor->shape().dim_size(0) != op_rank) {
514 return errors::InvalidArgument(
515 "padding_high must be a 1D tensor of size ", op_rank);
516 }
517 if (padding_interior_tensor->shape().dims() != 1 ||
518 padding_interior_tensor->shape().dim_size(0) != op_rank) {
519 return errors::InvalidArgument(
520 "padding_interior must be a 1D tensor of size ", op_rank);
521 }
522 std::vector<shape_inference::DimensionHandle> output_dims;
523 output_dims.reserve(op_rank);
524 for (int64_t i = 0; i < op_rank; ++i) {
525 int64_t low, high, interior;
526 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
527 TF_RETURN_IF_ERROR(
528 c->GetScalarFromTensor(padding_high_tensor, i, &high));
529 TF_RETURN_IF_ERROR(
530 c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
531 if (interior < 0) {
532 return errors::InvalidArgument(
533 "padding_interior must contain only non-negative values, found ",
534 interior);
535 }
536
537 shape_inference::DimensionHandle orig_size_handle =
538 c->Dim(input_shape_handle, i);
539 if (c->ValueKnown(orig_size_handle)) {
540 auto orig_dim = c->Value(orig_size_handle);
541 int64_t new_dim = orig_dim + low + high;
542 if (orig_dim > 0) {
543 new_dim += interior * (orig_dim - 1);
544 }
545 if (new_dim < 0) {
546 return errors::InvalidArgument(
547 "resulting padded dimension has negative size ", new_dim);
548 }
549 output_dims.emplace_back(c->MakeDim(new_dim));
550 } else {
551 output_dims.emplace_back(c->UnknownDim());
552 }
553 }
554
555 c->set_output(0, c->MakeShape(output_dims));
556 return Status::OK();
557 })
558 .Doc(R"doc(
559 Wraps the XLA Pad operator, documented at
560 https://www.tensorflow.org/performance/xla/operation_semantics#pad
561 .
562
563 input: A `Tensor` of type T.
564 padding_value: A scalar `Tensor` of type T.
565 padding_low: the padding to apply at the start of each input dimensions. Must
566 be a compile-time constant 1D tensor of length equal to rank of input.
567 padding_high: the padding to apply at the end of each input dimension. Must
568 be a compile-time constant 1D tensor of length equal to rank of input.
569 padding_interior: the padding to apply between each input element. Must
570 be a compile-time constant 1D tensor of length equal to rank of input,
571 containing only non-negative values.
572 output: A `Tensor` of type T.
573 )doc");
574
575 REGISTER_OP("XlaRecv")
576 .Output("tensor: dtype")
577 .Attr("dtype: type")
578 .Attr("tensor_name: string")
579 .Attr("shape: shape")
580 .SetIsStateful()
__anon592c2f360402(shape_inference::InferenceContext* c) 581 .SetShapeFn([](shape_inference::InferenceContext* c) {
582 TensorShape shape_attr;
583 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
584 shape_inference::ShapeHandle s;
585 TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
586 c->set_output(0, s);
587 return Status::OK();
588 })
589 .Doc(R"doc(
590 Receives the named tensor from another XLA computation. Wraps the XLA Recv
591 operator documented at
592 https://www.tensorflow.org/performance/xla/operation_semantics#recv .
593
594 tensor: The tensor to receive.
595 dtype: The type of the tensor.
596 tensor_name: A string key that identifies the channel.
597 shape: The shape of the tensor.
598 )doc");
599
600 REGISTER_OP("XlaReduce")
601 .Input("input: T")
602 .Input("init_value: T")
603 .Attr("T: {numbertype, bool}")
604 .Attr("dimensions_to_reduce: list(int)")
605 .Attr("reducer: func")
606 .Output("output: T")
__anon592c2f360502(shape_inference::InferenceContext* c) 607 .SetShapeFn([](shape_inference::InferenceContext* c) {
608 if (c->RankKnown(c->input(0))) {
609 int rank = c->Rank(c->input(0));
610 std::vector<int64> dimensions_to_reduce;
611 TF_RETURN_IF_ERROR(
612 c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
613 std::set<int64> dims_set(dimensions_to_reduce.begin(),
614 dimensions_to_reduce.end());
615 auto dim_in_range = [rank](int64_t dim) {
616 return dim >= 0 && dim < rank;
617 };
618 const int dimensions_to_reduce_size = dimensions_to_reduce.size();
619 if (rank < dimensions_to_reduce_size ||
620 dims_set.size() != dimensions_to_reduce.size() ||
621 !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
622 return errors::InvalidArgument(
623 "Invalid dimensions_to_reduce argument to XlaReduce");
624 }
625 c->set_output(
626 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
627 } else {
628 c->set_output(0, c->input(0));
629 }
630 return Status::OK();
631 })
632 .Doc(R"doc(
633 Wraps the XLA Reduce operator, documented at
634 https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
635
636 input: the input tensor
637 init_value: a scalar representing the initial value for the reduction
638 reducer: a reducer function to apply
639 dimensions_to_reduce: dimension numbers over which to reduce
640 )doc");
641
642 REGISTER_OP("XlaVariadicReduce")
643 .Input("input: N * T")
644 .Input("init_value: N * T")
645 .Attr("N: int >= 1")
646 .Attr("T: {numbertype, bool}")
647 .Attr("dimensions_to_reduce: list(int)")
648 .Attr("reducer: func")
649 .Output("output: N * T")
__anon592c2f360702(shape_inference::InferenceContext* c) 650 .SetShapeFn([](shape_inference::InferenceContext* c) {
651 int n;
652 TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
653 for (int i = 0; i < n; i++) {
654 for (int j = 0; j < n; j++) {
655 c->MergeInput(i, c->input(j));
656 }
657 }
658 if (c->RankKnown(c->input(0))) {
659 int rank = c->Rank(c->input(0));
660 std::vector<int64> dimensions_to_reduce;
661 TF_RETURN_IF_ERROR(
662 c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
663 std::set<int64> dims_set(dimensions_to_reduce.begin(),
664 dimensions_to_reduce.end());
665 auto dim_in_range = [rank](int64_t dim) {
666 return dim >= 0 && dim < rank;
667 };
668 const int dimensions_to_reduce_size = dimensions_to_reduce.size();
669 if (rank < dimensions_to_reduce_size ||
670 dims_set.size() != dimensions_to_reduce.size() ||
671 !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
672 return errors::InvalidArgument(
673 "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
674 }
675 for (int i = 0; i < n; i++) {
676 c->set_output(
677 i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
678 }
679 } else {
680 for (int i = 0; i < n; i++) {
681 c->set_output(i, c->input(i));
682 }
683 }
684 return Status::OK();
685 })
686 .Doc(R"doc(
687 Wraps the variadic XLA Reduce operator.
688
689 Semantics are documented at
690 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.
691
692 This version is limited to operands of the same dtype.
693 XlaVariadicReduceV2 is a version that supports heterogeneous operands.
694
695 input: the input tensor(s)
696 init_value: scalar initial value(s) for the reduction
697 reducer: a reducer function to apply
698 dimensions_to_reduce: dimension numbers over which to reduce
699 )doc");
700
701 REGISTER_OP("XlaVariadicReduceV2")
702 .Input("inputs: T")
703 .Input("init_values: T")
704 .Attr("T: list(type) >= 1")
705 .Attr("dimensions_to_reduce: list(int)")
706 .Attr("reducer: func")
707 .Output("outputs: T")
__anon592c2f360902(shape_inference::InferenceContext* c) 708 .SetShapeFn([](shape_inference::InferenceContext* c) {
709 std::vector<shape_inference::ShapeHandle> input_shapes;
710 TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
711 std::vector<shape_inference::ShapeHandle> init_values_shapes;
712 TF_RETURN_IF_ERROR(c->input("init_values", &init_values_shapes));
713 const int nr_inputs = input_shapes.size();
714 if (nr_inputs != init_values_shapes.size()) {
715 return errors::InvalidArgument(
716 "Must specify the same number of inputs and init_values. ", "Got ",
717 nr_inputs, " and ", init_values_shapes.size());
718 }
719 if (nr_inputs == 0) {
720 return errors::InvalidArgument("Must specify at least one input");
721 }
722
723 shape_inference::ShapeHandle input_shape = input_shapes[0];
724 for (int i = 1; i < nr_inputs; ++i) {
725 shape_inference::ShapeHandle merged;
726 TF_RETURN_WITH_CONTEXT_IF_ERROR(
727 c->Merge(input_shape, input_shapes[i], &merged),
728 "All inputs must have the same shape. Input ", i,
729 " (zero-based) has shape ", c->DebugString(input_shapes[i]),
730 " incompatible with the shape ", "inferred from previous inputs ",
731 c->DebugString(input_shape));
732 input_shape = merged;
733 }
734 // All outputs have the same shape
735 shape_inference::ShapeHandle output_shape = c->UnknownShape();
736
737 if (c->RankKnown(input_shape)) {
738 int rank = c->Rank(input_shape);
739
740 std::vector<int64> dimensions_to_reduce;
741 TF_RETURN_IF_ERROR(
742 c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
743 std::set<int64> dims_set(dimensions_to_reduce.begin(),
744 dimensions_to_reduce.end());
745
746 auto dim_in_range = [rank](int64_t dim) {
747 return dim >= 0 && dim < rank;
748 };
749 const int dimensions_to_reduce_size = dimensions_to_reduce.size();
750 if (rank < dimensions_to_reduce_size ||
751 dims_set.size() != dimensions_to_reduce.size() ||
752 !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
753 return errors::InvalidArgument(
754 "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2");
755 }
756
757 std::vector<shape_inference::DimensionHandle> output_dims;
758 for (int64_t i = 0; i < rank; ++i) {
759 if (dims_set.find(i) == dims_set.end()) {
760 output_dims.emplace_back(c->Dim(input_shape, i));
761 }
762 }
763 output_shape = c->MakeShape(output_dims);
764 }
765 for (int i = 0; i < nr_inputs; ++i) {
766 c->set_output(i, output_shape);
767 }
768 return Status::OK();
769 })
770 .Doc(R"doc(
771 Wraps the variadic XLA Reduce operator.
772
773 Semantics are documented at
774 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.
775
776 This is an expanded version of XlaVariadicReduce, with support for
777 operands of different dtypes, and improved shape inference.
778
779 inputs: the input tensor(s)
780 init_values: scalar initial value(s) for the reduction
781 reducer: a reducer function to apply
782 dimensions_to_reduce: dimension numbers over which to reduce
783 )doc");
784
785 REGISTER_OP("XlaReduceWindow")
786 .Input("input: T")
787 .Input("init_value: T")
788 .Input("window_dimensions: Tindices")
789 .Input("window_strides: Tindices")
790 .Input("base_dilations: Tindices")
791 .Input("window_dilations: Tindices")
792 .Input("padding: Tindices")
793 .Attr("T: {numbertype, bool}")
794 .Attr("Tindices: {int32, int64}")
795 .Attr("computation: func")
796 .Output("output: T")
797 .SetShapeFn(UnchangedRank)
798 .Doc(R"doc(
799 Wraps the XLA ReduceWindow operator, documented at
800 https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
801
802 input: the input tensor
803 init_value: a scalar representing the initial value for the reduction
804 computation: a reducer function to apply
805 window_dimensions: the shape of the window
806 window_strides: the inter-window strides
807 padding: the padding to apply at the start and end of each input dimensions
808 )doc");
809
810 REGISTER_OP("XlaRngBitGenerator")
811 .Input("algorithm: int32")
812 .Input("initial_state: uint64")
813 .Input("shape: Tshape")
814 .Output("output_key: uint64")
815 .Output("output: dtype")
816 .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
817 .Attr("Tshape: {int32, int64} = DT_INT32")
__anon592c2f360b02(shape_inference::InferenceContext* c) 818 .SetShapeFn([](shape_inference::InferenceContext* c) {
819 shape_inference::ShapeHandle algorithm;
820 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm));
821 shape_inference::ShapeHandle initial_state;
822 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state));
823
824 c->set_output(0, initial_state);
825 shape_inference::ShapeHandle output;
826 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output));
827 c->set_output(1, output);
828 return Status::OK();
829 })
830 .Doc(R"doc(
831 Stateless PRNG bit generator.
832 Wraps the XLA RngBitGenerator operator, documented at
833 https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
834
835 algorithm: The PRNG algorithm to use, one of
836 tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
837 initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be
838 a u64[2] and for PHILOX a u64[3].
839 shape: The output shape of the generated data.
840 dtype: The type of the tensor.
841 )doc");
842
843 REGISTER_OP("XlaSelectAndScatter")
844 .Input("operand: T")
845 .Input("window_dimensions: Tindices")
846 .Input("window_strides: Tindices")
847 .Input("padding: Tindices")
848 .Input("source: T")
849 .Input("init_value: T")
850 .Attr("T: numbertype")
851 .Attr("Tindices: {int32, int64}")
852 .Attr("select: func")
853 .Attr("scatter: func")
854 .Output("output: T")
855 .SetShapeFn(UnchangedRank)
856 .Doc(R"doc(
857 Wraps the XLA SelectAndScatter operator, documented at
858 https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
859 .
860
861 operand: the input tensor
862 window_dimensions: the shape of the window
863 window_strides: the inter-window strides
864 padding: the padding to apply at the start and end of each input dimensions
865 source: a tensor of values to scatter
866 init_value: a scalar representing the initial value for the output tensor
867 select: a selection function to apply
868 scatter: a scatter function to apply
869 )doc");
870
871 REGISTER_OP("XlaSend")
872 .Input("tensor: T")
873 .Attr("T: type")
874 .Attr("tensor_name: string")
875 .SetIsStateful()
876 .SetShapeFn(shape_inference::UnknownShape)
877 .Doc(R"doc(
878 Sends the named tensor to another XLA computation. Wraps the XLA Send operator
879 documented at
880 https://www.tensorflow.org/performance/xla/operation_semantics#send .
881
882 tensor: The tensor to send.
883 tensor_name: A string key that identifies the channel.
884 )doc");
885
886 REGISTER_OP("XlaSort")
887 .Input("input: T")
888 .Output("output: T")
889 .Attr("T: type")
890 .SetShapeFn(shape_inference::UnchangedShape)
891 .Doc(R"doc(
892 Wraps the XLA Sort operator, documented at
893 https://www.tensorflow.org/performance/xla/operation_semantics#sort
894 .
895
896 Sorts a tensor. Currently only sorts in ascending order are supported.
897
898 input: A `Tensor` of type T.
899 output: A `Tensor` of type T.
900 )doc");
901
902 REGISTER_OP("XlaKeyValueSort")
903 .Input("keys: K")
904 .Input("values: V")
905 .Output("sorted_keys: K")
906 .Output("sorted_values: V")
907 .Attr("K: realnumbertype")
908 .Attr("V: type")
__anon592c2f360c02(shape_inference::InferenceContext* c) 909 .SetShapeFn([](shape_inference::InferenceContext* c) {
910 c->set_output(0, c->input(0));
911 c->set_output(1, c->input(1));
912 return Status::OK();
913 })
914 .Doc(R"doc(
915 Wraps the XLA Sort operator, documented at
916 https://www.tensorflow.org/performance/xla/operation_semantics#sort
917 .
918
919 Sorts a tensor. Currently only sorts in ascending order are supported.
920
921 keys: A `Tensor` of type K.
922 values: A `Tensor` of type V.
923 sorted_keys: A `Tensor` of type K.
924 sorted_values: A `Tensor` of type V.
925 )doc");
926
927 REGISTER_OP("XlaVariadicSort")
928 .Input("inputs: T")
929 .Input("dimension: int32")
930 .Output("outputs: T")
931 .Attr("T: list(type) >= 1")
932 .Attr("comparator: func")
933 .Attr("is_stable: bool")
__anon592c2f360d02(shape_inference::InferenceContext* c) 934 .SetShapeFn([](shape_inference::InferenceContext* c) {
935 std::vector<shape_inference::ShapeHandle> input_shapes;
936 TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
937 TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
938 return Status::OK();
939 })
940 .Doc(R"doc(
941 Wraps the XLA Sort operator, documented at
942 https://www.tensorflow.org/performance/xla/operation_semantics#sort
943 .
944
945 Sorts one or more tensors, with support for custom comparator, dimension, and
946 is_stable attributes.
947
948 inputs: A list of `Tensor` of identical shape but possibly different types.
949 dimension: The dimension along which to sort. Must be a compile-time constant.
950 is_stable: Whether to use stable sort.
951 comparator: A comparator function to apply to 2*N scalars and returning a
952 boolean. N is the number of sort inputs. If you want to sort in ascending
953 order then the comparator should perform a less-than comparison.
954 outputs: A list of `Tensor` of same shape and types as the `input`.
955 )doc");
956
957 // TODO(b/37549631) setting the While Op to always be stateful is too
958 // conservative.
959 REGISTER_OP("XlaWhile")
960 .Input("input: T")
961 .Output("output: T")
962 .Attr("T: list(type) >= 0")
963 .Attr("cond: func")
964 .Attr("body: func")
965 .SetIsStateful()
966 .SetShapeFn(shape_inference::UnknownShape)
967 .Doc(R"doc(
968 output = input; While (Cond(output)) { output = Body(output) }
969
970 input: A list of input tensors whose types are T.
971 output: A list of output tensors whose types are T.
972 cond: A function takes 'input' and returns a tensor. If the tensor is
973 a scalar of non-boolean, the scalar is converted to a boolean
974 according to the following rule: if the scalar is a numerical
975 value, non-zero means True and zero means False; if the scalar is
976 a string, non-empty means True and empty means False. If the
977 tensor is not a scalar, non-emptiness means True and False
978 otherwise.
979 body: A function that takes a list of tensors and returns another
980 list of tensors. Both lists have the same types as specified by T.
981 )doc");
982
983 REGISTER_OP("XlaDequantize")
984 .Input("input: uint32")
985 .Output("output: bfloat16")
986 .Attr("min_range: float")
987 .Attr("max_range: float")
988 .Attr("mode: string")
989 .Attr("transpose_output: bool")
990 .SetIsStateful()
991 .SetShapeFn(shape_inference::UnknownShape)
992 .Doc(R"doc(
993 Takes the packed uint32 input and unpacks the input to uint8 to do
994 Dequantization on device.
995
996 input: Input tensors whose types is uint32, shape is [d0, ..., dn].
997 output: Output tensors whose types is bloat16. If transpose_output is true,
998 output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
999 is false, output shape is [d0,..., dn * 4].
1000 min_range: The minimum scalar value possibly produced for the input.
1001 max_range: The maximum scalar value possibly produced for the input.
1002 mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
1003 transpose_output: Boolean to determine if output is transposed. transpose_output
1004 is faster when input is large and rank of input is higher than 1.
1005 )doc");
1006
1007 REGISTER_OP("XlaEinsum")
1008 .Input("a: T")
1009 .Input("b: T")
1010 .Output("product: T")
1011 .Attr("equation: string")
1012 .Attr("T: {complex64, bfloat16, float}")
__anon592c2f360e02(shape_inference::InferenceContext* context) 1013 .SetShapeFn([](shape_inference::InferenceContext* context) {
1014 string equation;
1015 TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
1016 // XlaEinsum supports only two-input einsum equations.
1017 if (!absl::StrContains(equation, ",")) {
1018 return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
1019 equation);
1020 }
1021 // Use EinsumShape for the rest of the inference now that we know we must
1022 // have a two-input einsum.
1023 return shape_inference::EinsumShape(context);
1024 })
1025 .Doc(R"doc(
1026 An op which supports basic einsum op with 2 inputs and 1 output.
1027
1028 This op has better TPU performance since it doesn't have explicitly reshape and
1029 transpose operations as tf.einsum does.
1030 )doc");
1031
1032 REGISTER_OP("XlaSpmdFullToShardShape")
1033 .Input("input: T")
1034 .Output("output: T")
1035 .Attr("T: type")
1036 .Attr("manual_sharding: string")
__anon592c2f360f02(shape_inference::InferenceContext* c) 1037 .SetShapeFn([](shape_inference::InferenceContext* c) {
1038 auto input_handle = c->input(0);
1039 if (!c->RankKnown(input_handle)) {
1040 return shape_inference::UnknownShape(c);
1041 }
1042 string sharding_attr;
1043 TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
1044 xla::OpSharding sharding;
1045 sharding.ParseFromString(sharding_attr);
1046 if (sharding.type() != xla::OpSharding::OTHER) {
1047 return shape_inference::UnchangedShape(c);
1048 }
1049 std::vector<shape_inference::DimensionHandle> dims;
1050 for (int64_t i = 0; i < c->Rank(input_handle); ++i) {
1051 auto dim = c->Value(c->Dim(input_handle, i));
1052 int64_t partitions_i = sharding.tile_assignment_dimensions(i);
1053 if (dim != shape_inference::InferenceContext::kUnknownDim &&
1054 partitions_i != 1) {
1055 dim = (dim + partitions_i - 1) / partitions_i;
1056 }
1057 dims.push_back(c->MakeDim(dim));
1058 }
1059 c->set_output(0, c->MakeShape(dims));
1060 return Status::OK();
1061 })
1062 .Doc(R"doc(
1063 An op used by XLA SPMD partitioner to switch from automatic partitioning to
1064 manual partitioning. It annotates the input (full-shape, to be automatically
1065 partitioned) with the same sharding used by manual partitioning, and outputs a
1066 shard-shaped tensor to be consumed by later manually-partitioned ops. If the
1067 shape is not evenly partitionable, the padding region will be masked with 0s.
1068 )doc");
1069
1070 REGISTER_OP("XlaSpmdShardToFullShape")
1071 .Input("input: T")
1072 .Output("output: T")
1073 .Attr("T: type")
1074 .Attr("manual_sharding: string")
1075 .Attr("full_shape: shape")
__anon592c2f361002(shape_inference::InferenceContext* c) 1076 .SetShapeFn([](shape_inference::InferenceContext* c) {
1077 TensorShape shape_attr;
1078 TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
1079 shape_inference::ShapeHandle s;
1080 TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
1081 c->set_output(0, s);
1082 return Status::OK();
1083 })
1084 .Doc(R"doc(
1085 An op used by XLA SPMD partitioner to switch from manual partitioning to
1086 automatic partitioning. It converts the shard-shaped, manually partitioned input
1087 into full-shaped tensor to be partitioned automatically with the same sharding
1088 used by manual partitioning.
1089 )doc");
1090
1091 REGISTER_OP("XlaSharding")
1092 .Input("input: T")
1093 .Output("output: T")
1094 .Attr("T: type")
1095 .Attr("sharding: string = ''")
1096 .SetShapeFn(shape_inference::UnchangedShape)
1097 .Doc(R"doc(
1098 An op which shards the input based on the given sharding attribute.
1099 )doc");
1100
1101 REGISTER_OP("XlaReplicaId")
1102 .Output("id: int32")
__anon592c2f361102(shape_inference::InferenceContext* context) 1103 .SetShapeFn([](shape_inference::InferenceContext* context) {
1104 context->set_output(0, context->MakeShape({}));
1105 return Status::OK();
1106 })
1107 .Doc("Replica ID.");
1108
1109 REGISTER_OP("XlaGather")
1110 .Input("operand: T")
1111 .Input("start_indices: Tindices")
1112 .Input("slice_sizes: Tindices")
1113 .Attr("dimension_numbers: string")
1114 .Attr("indices_are_sorted: bool")
1115 .Attr("T: {numbertype, bool}")
1116 .Attr("Tindices: {int32, int64}")
1117 .Output("output: T")
1118 .SetShapeFn(shape_inference::UnknownShape)
1119 .Doc(R"doc(
1120 Wraps the XLA Gather operator documented at
1121 https://www.tensorflow.org/xla/operation_semantics#gather
1122 operand: The array we're gathering from.
1123 start_indices: Array containing the starting indices of the slices we gather.
1124 dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
1125 slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
1126 indices_are_sorted: Boolean indicating if the indices are sorted.
1127 )doc");
1128
1129 REGISTER_OP("XlaScatter")
1130 .Input("operand: T")
1131 .Input("scatter_indices: Tindices")
1132 .Input("updates: T")
1133 .Attr("update_computation: func")
1134 .Attr("dimension_numbers: string")
1135 .Attr("indices_are_sorted: bool")
1136 .Attr("T: {numbertype, bool}")
1137 .Attr("Tindices: {int32, int64}")
1138 .Output("output: T")
1139 .SetShapeFn(shape_inference::UnchangedShape)
1140 .Doc(R"doc(
1141 Wraps the XLA Scatter operator documented at
1142 https://www.tensorflow.org/xla/operation_semantics#scatter.
1143
1144 operand: Array to be scattered into.
1145 scatter_indices: Array containing the starting indices of the slices that must
1146 be scattered to.
1147 updates: Array containing the values that must be used for scattering.
1148 update_computation: Computation to be used for combining the existing values in
1149 the input array and the updates during scatter.
1150 dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
1151 indices_are_sorted: Boolean indicating if the indices are sorted.
1152 )doc");
1153
1154 } // namespace
1155 } // namespace tensorflow
1156