/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <set>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

// Note: Most of the operators defined in this module are used by the jax2tf
// converter (see go/jax2tf for details) and are used in SavedModel produced
// by jax2tf. Hence, we need to maintain backwards compatibility for these
// operators. Please reach out to the JAX team if you want to make changes.

namespace tensorflow {
namespace {

// Helper shape function for operators that return an output with the same rank
// as their first input.
Status UnchangedRank(shape_inference::InferenceContext* c) {
  if (c->RankKnown(c->input(0))) {
    c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
  } else {
    c->set_output(0, c->input(0));
  }
  return OkStatus();
}

REGISTER_OP("XlaBroadcastHelper")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("broadcast_dims: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Output("lhs_output: T")
    .Output("rhs_output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Helper operator for performing XLA-style broadcasts

Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.

lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");
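
// Illustrative example (hypothetical shapes): with lhs = f32[2, 3, 4],
// rhs = f32[3], and broadcast_dims = [1], rhs_output is reshaped to
// [1, 3, 1] (size-1 dimensions added per XLA's broadcasting rules) while
// lhs_output is returned unchanged, so both outputs have rank 3.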

REGISTER_OP("XlaSelfAdjointEig")
    .Input("a: T")
    .Attr("lower: bool")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Output("w: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).

Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
i=0...N-1.

a: the input tensor.

lower: a boolean that specifies whether the calculation is done with the lower
triangular part or the upper triangular part.

max_iter: maximum number of sweep updates (a sweep covers the whole lower or
upper triangular part, depending on the `lower` parameter). Heuristically, it
has been argued that approximately log(N) sweeps are needed in practice (Ref:
Golub & van Loan, "Matrix Computations").

epsilon: the tolerance ratio.

w: The eigenvalues in ascending order, each repeated according to its
multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
eigenvalue w[..., i].
)doc");
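
// Shape note (illustrative, assuming the standard batched layout implied by
// the doc above): for a of shape [..., N, N], w has shape [..., N] and v has
// shape [..., N, N], with v[..., :, i] the eigenvector for w[..., i].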

REGISTER_OP("XlaSvd")
    .Input("a: T")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Attr("precision_config: string")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the singular value decomposition of a batch of matrices
(Note: Only real inputs are supported).

Computes the singular values and singular vectors of the innermost M-by-N
matrices in tensor such that
tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).

a: the input tensor.

max_iter: maximum number of sweep updates (a sweep covers the whole lower or
upper triangular part). Heuristically, it has been argued that approximately
log(min(M, N)) sweeps are needed in practice (Ref: Golub & van Loan, "Matrix
Computations").

epsilon: the tolerance ratio.

precision_config: a serialized xla::PrecisionConfig proto.

s: Singular values. The values are sorted in reverse order of magnitude, so
s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");

REGISTER_OP("XlaConv")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaConvV2")
    .Input("lhs: LhsT")
    .Input("rhs: RhsT")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("LhsT: numbertype")
    .Attr("RhsT: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Attr("preferred_element_type: numbertype")
    .Attr("batch_group_count: int = 1")
    .Output("output: preferred_element_type")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: input tensor
rhs: kernel tensor
window_strides: inter-window strides
padding: padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: serialized xla::ConvolutionDimensionNumbers proto.
precision_config: serialized xla::PrecisionConfig proto.
preferred_element_type: type of the output tensor.
batch_group_count: number of batch groups or grouped filters.
)doc");

static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
  shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
  if (!c->RankKnown(lhs_shape_handle) || !c->RankKnown(rhs_shape_handle)) {
    return shape_inference::UnknownShape(c);
  }

  string dimension_numbers_string;
  TF_RETURN_IF_ERROR(
      c->GetAttr("dimension_numbers", &dimension_numbers_string));

  xla::DotDimensionNumbers dimension_numbers;
  dimension_numbers.ParseFromString(dimension_numbers_string);

  // Check that the number of contracting dimensions matches.
  if (dimension_numbers.lhs_contracting_dimensions_size() !=
      dimension_numbers.rhs_contracting_dimensions_size())
    return errors::InvalidArgument(
        "Must specify the same number of contracting dimensions for lhs "
        "and rhs. Got: ",
        dimension_numbers.lhs_contracting_dimensions_size(), " and ",
        dimension_numbers.rhs_contracting_dimensions_size());

  // Check that contracting dimension sizes match.
  for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
       ++i) {
    const int64_t lhs_contracting_dimension =
        dimension_numbers.lhs_contracting_dimensions(i);
    const int64_t rhs_contracting_dimension =
        dimension_numbers.rhs_contracting_dimensions(i);
    shape_inference::DimensionHandle unused;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension),
                 c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension),
                 &unused),
        "For contracting dimension ", i, " which is lhs dimension ",
        lhs_contracting_dimension, " and rhs dimension ",
        rhs_contracting_dimension);
  }

  // Check that the number of batch dimensions matches.
  if (dimension_numbers.lhs_batch_dimensions_size() !=
      dimension_numbers.rhs_batch_dimensions_size())
    return errors::InvalidArgument(
        "Must specify the same number of batch dimensions for lhs "
        "and rhs. Got: ",
        dimension_numbers.lhs_batch_dimensions_size(), " and ",
        dimension_numbers.rhs_batch_dimensions_size());

  // Each contraction removes one dimension from both lhs and rhs, and each
  // batch dimension is counted once in the result, so the output rank is
  // lhs rank + rhs rank - 2 * #contracting dims - #batch dims. A scalar
  // input contributes 0 to the rank of the result. Generate the result
  // dimensions in order: batch dimensions, then the non-contracting,
  // non-batch lhs dimensions, then the non-contracting, non-batch rhs
  // dimensions.
  std::vector<shape_inference::DimensionHandle> output_dims;

  // Check that batch dimension sizes match, and add them to output_dims.
  for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
    const int64_t lhs_batch_dimension =
        dimension_numbers.lhs_batch_dimensions(i);
    const int64_t rhs_batch_dimension =
        dimension_numbers.rhs_batch_dimensions(i);
    shape_inference::DimensionHandle out;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->Merge(c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension),
                 c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension), &out),
        "For batch dimension ", i, " which is lhs dimension ",
        lhs_batch_dimension, " and rhs dimension ", rhs_batch_dimension);
    output_dims.emplace_back(out);
  }

  const int32_t lhs_rank = c->Rank(lhs_shape_handle);
  for (int64_t i = 0; i < lhs_rank; ++i) {
    if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
                              i) ||
        absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
      continue;
    }
    output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
  }

  const int32_t rhs_rank = c->Rank(rhs_shape_handle);
  for (int64_t i = 0; i < rhs_rank; ++i) {
    if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
                              i) ||
        absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
      continue;
    }
    output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
  }

  c->set_output(0, c->MakeShape(output_dims));
  return OkStatus();
}
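
// Worked example (hypothetical shapes, for illustration only): with
// lhs = f32[2, 3, 4] and rhs = f32[2, 4, 5], batch dimensions
// {lhs: 0, rhs: 0}, and contracting dimensions {lhs: 2, rhs: 1}, the
// function above infers an output shape of [2, 3, 5]: the merged batch
// dimension, then the remaining lhs dimension, then the remaining rhs
// dimension.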

REGISTER_OP("XlaDot")
    .Input("lhs: T")
    .Input("rhs: T")
    .Attr("T: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(XlaDotShapeFunction)
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaDotV2")
    .Input("lhs: LhsT")
    .Input("rhs: RhsT")
    .Attr("LhsT: numbertype")
    .Attr("RhsT: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Attr("preferred_element_type: numbertype")
    .Output("output: preferred_element_type")
    .SetShapeFn(XlaDotShapeFunction)
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
preferred_element_type: the type of the output tensor.
)doc");

REGISTER_OP("XlaSetBound")
    .Input("input: int32")
    .Input("bound: int32")
    .Output("output: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Sets a bound for the given input value as a hint to the XLA
compiler, and returns the same value.
)doc");

REGISTER_OP("XlaSetDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Input("size: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Makes a static dimension into an XLA bounded dynamic dimension.
The current static dimension size will become the bound and the second
operand becomes the dynamic size of the dimension.)doc");

REGISTER_OP("XlaRemoveDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Inverse of XlaSetDynamicDimensionSize.

Makes an XLA bounded dynamic dimension into a static dimension. The bound of
the size of dimension `dim_index` becomes the static dimension size.
)doc");

REGISTER_OP("XlaDynamicSlice")
    .Input("input: T")
    .Input("start_indices: Tindices")
    .Input("size_indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      shape_inference::ShapeHandle size_indices_shape = c->input(2);
      if (!c->RankKnown(size_indices_shape)) {
        return UnchangedRank(c);
      }
      if (c->Rank(size_indices_shape) != 1) {
        return errors::InvalidArgument("size_indices must be a 1D tensor");
      }
      shape_inference::ShapeHandle size_indices_value;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &size_indices_value));
      if (!c->RankKnown(size_indices_value)) {
        // If we cannot tell the rank of the output from the value of
        // size_indices, perhaps we can find it from the rank of the first
        // operand.
        return UnchangedRank(c);
      }
      c->set_output(0, size_indices_value);
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA DynamicSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
.

DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
dimension -- [start, start + size). The shape of start_indices must have rank
1, with dimension size equal to the rank of operand.

input: A `Tensor` of type T.

start_indices: Rank 1 tensor of N integers containing the starting indices of
the slice for each dimension. Each value must be greater than or equal to zero.

size_indices: Rank 1 tensor of N integers containing the slice size for each
dimension. Each value must be strictly greater than zero, and start + size
must be less than or equal to the size of the dimension to avoid
implementation-defined behavior.
)doc");
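
// Illustrative example (hypothetical values): for a [4, 6] input,
// start_indices = [1, 2] and size_indices = [2, 3] select the sub-array
// input[1:3, 2:5], producing an output of shape [2, 3].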

REGISTER_OP("XlaDynamicUpdateSlice")
    .Input("input: T")
    .Input("update: T")
    .Input("indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.

XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. `indices`
must be a rank-1 tensor with dimension size equal to the rank of `input`.

Handling of out-of-bounds slice indices is implementation-defined.

input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank
of `input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");

// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
    .Input("cond: Tcond")
    .Input("inputs: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).

cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
else_branch(inputs). The input shapes of the then_branch and
else_branch must match.
then_branch: A function that takes 'inputs' and returns a list of tensors,
whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of tensors,
whose types are the same as what then_branch returns.
)doc");

REGISTER_OP("XlaPad")
    .Input("input: T")
    .Input("padding_value: T")
    .Input("padding_low: Tindices")
    .Input("padding_high: Tindices")
    .Input("padding_interior: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle input_shape_handle = c->input(0);
      if (!c->RankKnown(input_shape_handle)) {
        return UnchangedRank(c);
      }
      const int32_t op_rank = c->Rank(input_shape_handle);

      shape_inference::ShapeHandle padding_shape_handle = c->input(1);
      if (c->RankKnown(padding_shape_handle) &&
          c->Rank(padding_shape_handle) != 0) {
        return errors::InvalidArgument(
            "padding_value input must be scalar, found rank ",
            c->Rank(padding_shape_handle));
      }
      const Tensor* padding_low_tensor = c->input_tensor(2);
      const Tensor* padding_high_tensor = c->input_tensor(3);
      const Tensor* padding_interior_tensor = c->input_tensor(4);
      if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
          padding_interior_tensor == nullptr) {
        return UnchangedRank(c);
      }

      if (padding_low_tensor->shape().dims() != 1 ||
          padding_low_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_low must be a 1D tensor of size ", op_rank);
      }
      if (padding_high_tensor->shape().dims() != 1 ||
          padding_high_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_high must be a 1D tensor of size ", op_rank);
      }
      if (padding_interior_tensor->shape().dims() != 1 ||
          padding_interior_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_interior must be a 1D tensor of size ", op_rank);
      }
      std::vector<shape_inference::DimensionHandle> output_dims;
      output_dims.reserve(op_rank);
      for (int64_t i = 0; i < op_rank; ++i) {
        int64_t low, high, interior;
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_low_tensor, i, &low));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_high_tensor, i, &high));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
        if (interior < 0) {
          return errors::InvalidArgument(
              "padding_interior must contain only non-negative values, found ",
              interior);
        }

        shape_inference::DimensionHandle orig_size_handle =
            c->Dim(input_shape_handle, i);
        if (c->ValueKnown(orig_size_handle)) {
          auto orig_dim = c->Value(orig_size_handle);
          int64_t new_dim = orig_dim + low + high;
          if (orig_dim > 0) {
            new_dim += interior * (orig_dim - 1);
          }
          if (new_dim < 0) {
            return errors::InvalidArgument(
                "resulting padded dimension has negative size ", new_dim);
          }
          output_dims.emplace_back(c->MakeDim(new_dim));
        } else {
          output_dims.emplace_back(c->UnknownDim());
        }
      }

      c->set_output(0, c->MakeShape(output_dims));
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Pad operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#pad
.

input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_high: the padding to apply at the end of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_interior: the padding to apply between each input element. Must
be a compile-time constant 1D tensor of length equal to rank of input,
containing only non-negative values.
output: A `Tensor` of type T.
)doc");
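
// Worked example (hypothetical values): an input dimension of size 3 with
// low = 1, high = 2, and interior = 1 pads to size
// 1 + 2 + 3 + 1 * (3 - 1) = 8, matching the size computation in the shape
// function above.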

REGISTER_OP("XlaRecv")
    .Output("tensor: dtype")
    .Attr("dtype: type")
    .Attr("tensor_name: string")
    .Attr("shape: shape")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return OkStatus();
    })
    .Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
https://www.tensorflow.org/performance/xla/operation_semantics#recv .

tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");

REGISTER_OP("XlaReduce")
    .Input("input: T")
    .Input("init_value: T")
    .Attr("T: {numbertype, bool}")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64_t> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
                                   dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaReduce");
        }
        c->set_output(
            0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
      } else {
        c->set_output(0, c->input(0));
      }
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Reduce operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reduce .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
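
// Illustrative example (hypothetical shapes): reducing a rank-3 input over
// dimensions_to_reduce = [0, 2] yields a rank-1 output keeping only
// dimension 1, consistent with the rank arithmetic in the shape function
// above.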

REGISTER_OP("XlaVariadicReduce")
    .Input("input: N * T")
    .Input("init_value: N * T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, bool}")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: N * T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          c->MergeInput(i, c->input(j));
        }
      }
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64_t> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
                                   dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
        }
        for (int i = 0; i < n; i++) {
          c->set_output(
              i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
        }
      } else {
        for (int i = 0; i < n; i++) {
          c->set_output(i, c->input(i));
        }
      }
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

This version is limited to operands of the same dtype.
XlaVariadicReduceV2 is a version that supports heterogeneous operands.

input: the input tensor(s)
init_value: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaVariadicReduceV2")
    .Input("inputs: T")
    .Input("init_values: T")
    .Attr("T: list(type) >= 1")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("outputs: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      std::vector<shape_inference::ShapeHandle> init_values_shapes;
      TF_RETURN_IF_ERROR(c->input("init_values", &init_values_shapes));
      const int nr_inputs = input_shapes.size();
      if (nr_inputs != init_values_shapes.size()) {
        return errors::InvalidArgument(
            "Must specify the same number of inputs and init_values. ", "Got ",
            nr_inputs, " and ", init_values_shapes.size());
      }
      if (nr_inputs == 0) {
        return errors::InvalidArgument("Must specify at least one input");
      }

      shape_inference::ShapeHandle input_shape = input_shapes[0];
      for (int i = 1; i < nr_inputs; ++i) {
        shape_inference::ShapeHandle merged;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->Merge(input_shape, input_shapes[i], &merged),
            "All inputs must have the same shape. Input ", i,
            " (zero-based) has shape ", c->DebugString(input_shapes[i]),
            " incompatible with the shape ", "inferred from previous inputs ",
            c->DebugString(input_shape));
        input_shape = merged;
      }
      // All outputs have the same shape.
      shape_inference::ShapeHandle output_shape = c->UnknownShape();

      if (c->RankKnown(input_shape)) {
        int rank = c->Rank(input_shape);

        std::vector<int64_t> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64_t> dims_set(dimensions_to_reduce.begin(),
                                   dimensions_to_reduce.end());

        auto dim_in_range = [rank](int64_t dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2");
        }

        std::vector<shape_inference::DimensionHandle> output_dims;
        for (int64_t i = 0; i < rank; ++i) {
          if (dims_set.find(i) == dims_set.end()) {
            output_dims.emplace_back(c->Dim(input_shape, i));
          }
        }
        output_shape = c->MakeShape(output_dims);
      }
      for (int i = 0; i < nr_inputs; ++i) {
        c->set_output(i, output_shape);
      }
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

This is an expanded version of XlaVariadicReduce, with support for
operands of different dtypes, and improved shape inference.

inputs: the input tensor(s)
init_values: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaReduceWindow")
    .Input("input: T")
    .Input("init_value: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("base_dilations: Tindices")
    .Input("window_dilations: Tindices")
    .Input("padding: Tindices")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Attr("computation: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
)doc");

REGISTER_OP("XlaRngBitGenerator")
    .Input("algorithm: int32")
    .Input("initial_state: uint64")
    .Input("shape: Tshape")
    .Output("output_key: uint64")
    .Output("output: dtype")
    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle algorithm;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm));
      shape_inference::ShapeHandle initial_state;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state));

      c->set_output(0, initial_state);
      shape_inference::ShapeHandle output;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output));
      c->set_output(1, output);
      return OkStatus();
    })
    .Doc(R"doc(
Stateless PRNG bit generator.
Wraps the XLA RngBitGenerator operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

algorithm: The PRNG algorithm to use, one of
tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be
a u64[2] and for PHILOX a u64[3].
shape: The output shape of the generated data.
dtype: The type of the tensor.
)doc");
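
// Illustrative note (hypothetical values): with the THREEFRY algorithm the
// initial_state is a u64[2]; the shape function above forwards that shape to
// output_key, and output takes the shape requested via `shape`.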

REGISTER_OP("XlaSelectAndScatter")
    .Input("operand: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("source: T")
    .Input("init_value: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("select: func")
    .Attr("scatter: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA SelectAndScatter operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
.

operand: the input tensor
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
source: a tensor of values to scatter
init_value: a scalar representing the initial value for the output tensor
select: a selection function to apply
scatter: a scatter function to apply
)doc");

REGISTER_OP("XlaSend")
    .Input("tensor: T")
    .Attr("T: type")
    .Attr("tensor_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
https://www.tensorflow.org/performance/xla/operation_semantics#send .

tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");

REGISTER_OP("XlaSort")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorting in ascending order is supported.

input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");

REGISTER_OP("XlaKeyValueSort")
    .Input("keys: K")
    .Input("values: V")
    .Output("sorted_keys: K")
    .Output("sorted_values: V")
    .Attr("K: realnumbertype")
    .Attr("V: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorting in ascending order is supported.

keys: A `Tensor` of type K.
values: A `Tensor` of type V.
sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");

REGISTER_OP("XlaVariadicSort")
    .Input("inputs: T")
    .Input("dimension: int32")
    .Output("outputs: T")
    .Attr("T: list(type) >= 1")
    .Attr("comparator: func")
    .Attr("is_stable: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.

inputs: A list of `Tensor` of identical shape but possibly different types.
dimension: The dimension along which to sort. Must be a compile-time constant.
is_stable: Whether to use stable sort.
comparator: A comparator function to apply to 2*N scalars and return a
boolean. N is the number of sort inputs. If you want to sort in ascending
order then the comparator should perform a less-than comparison.
outputs: A list of `Tensor` of the same shapes and types as `inputs`.
)doc");
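
// Illustrative comparator contract (an assumption based on XLA's sort
// semantics, where parameters 2*i and 2*i+1 are the lhs/rhs elements of
// input i): sorting N = 2 inputs (keys, values) in ascending key order uses
// a comparator over (lhs_key, rhs_key, lhs_value, rhs_value) that returns
// lhs_key < rhs_key.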

// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor. If the tensor is
a non-boolean scalar, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True, otherwise False.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");

REGISTER_OP("XlaDequantize")
    .Input("input: uint32")
    .Output("output: bfloat16")
    .Attr("min_range: float")
    .Attr("max_range: float")
    .Attr("mode: string")
    .Attr("transpose_output: bool")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Takes the packed uint32 input and unpacks it to uint8 to perform
dequantization on device.

input: Input tensor whose type is uint32, with shape [d0, ..., dn].
output: Output tensor whose type is bfloat16. If transpose_output is true,
output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
is false, output shape is [d0, ..., dn * 4].
min_range: The minimum scalar value possibly produced for the input.
max_range: The maximum scalar value possibly produced for the input.
mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
transpose_output: Boolean to determine if output is transposed. transpose_output
is faster when input is large and rank of input is higher than 1.
)doc");
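
// Illustrative example (hypothetical shapes): a uint32 input of shape [2, 8]
// unpacks to four uint8 values per element, so with transpose_output = false
// the bfloat16 output has shape [2, 32].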

REGISTER_OP("XlaEinsum")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("equation: string")
    .Attr("T: {complex64, bfloat16, float}")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      string equation;
      TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
      // XlaEinsum supports only two-input einsum equations.
      if (!absl::StrContains(equation, ",")) {
        return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
                                       equation);
      }
      // Use EinsumShape for the rest of the inference now that we know we must
      // have a two-input einsum.
      return shape_inference::EinsumShape(context);
    })
    .Doc(R"doc(
An op which supports basic einsum op with 2 inputs and 1 output.

This op has better TPU performance since it doesn't have the explicit reshape
and transpose operations that tf.einsum does.
)doc");

REGISTER_OP("XlaSpmdFullToShardShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("dim: int = -1")
    .Attr("unspecified_dims: list(int) = []")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      auto input_handle = c->input(0);
      if (!c->RankKnown(input_handle)) {
        return shape_inference::UnknownShape(c);
      }
      string sharding_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
      int32 single_dim;
      TF_RETURN_IF_ERROR(c->GetAttr("dim", &single_dim));
      xla::OpSharding sharding;
      sharding.ParseFromString(sharding_attr);
      if (sharding.type() != xla::OpSharding::OTHER) {
        return shape_inference::UnchangedShape(c);
      }
      std::vector<shape_inference::DimensionHandle> dims;
      for (int64_t i = 0; i < c->Rank(input_handle); ++i) {
        auto dim = c->Value(c->Dim(input_handle, i));
        if (single_dim < 0 || single_dim == i) {
          int64_t partitions_i = sharding.tile_assignment_dimensions(i);
          if (dim != shape_inference::InferenceContext::kUnknownDim &&
              partitions_i != 1) {
            dim = (dim + partitions_i - 1) / partitions_i;
          }
        }
        dims.push_back(c->MakeDim(dim));
      }
      c->set_output(0, c->MakeShape(dims));
      return OkStatus();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually-partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
The conversion can happen partially in subgroups, by specifying the dim
attribute, where only that dim will be converted.
)doc");
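
// Worked example (hypothetical values): a full shape of [9] tiled across 4
// partitions yields a shard shape of [(9 + 4 - 1) / 4] = [3], as computed by
// the ceiling division in the shape function above; the padding region of
// the last shard is masked with 0s.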

REGISTER_OP("XlaSpmdShardToFullShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("full_shape: shape")
    .Attr("dim: int = -1")
    .Attr("unspecified_dims: list(int) = []")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return OkStatus();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned
input into a full-shaped tensor to be partitioned automatically with the same
sharding used by manual partitioning. The conversion can happen partially in
subgroups, by specifying the dim attribute, where only that dim will be
converted.
)doc");

REGISTER_OP("XlaSharding")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("sharding: string = ''")
    .Attr("unspecified_dims: list(int) = []")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
An op which shards the input based on the given sharding attribute. It can
selectively annotate a subset of tensor dimensions by skipping unspecified_dims,
and the sharding annotation should be replicated in those dims.
)doc");

REGISTER_OP("XlaReplicaId")
    .Output("id: int32")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      context->set_output(0, context->MakeShape({}));
      return OkStatus();
    })
    .Doc("Replica ID.");

REGISTER_OP("XlaGather")
    .Input("operand: T")
    .Input("start_indices: Tindices")
    .Input("slice_sizes: Tindices")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps the XLA Gather operator documented at
https://www.tensorflow.org/xla/operation_semantics#gather.

operand: The array we're gathering from.
start_indices: Array containing the starting indices of the slices we gather.
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

REGISTER_OP("XlaScatter")
    .Input("operand: T")
    .Input("scatter_indices: Tindices")
    .Input("updates: T")
    .Attr("update_computation: func")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Scatter operator documented at
https://www.tensorflow.org/xla/operation_semantics#scatter.

operand: Array to be scattered into.
scatter_indices: Array containing the starting indices of the slices that must
be scattered to.
updates: Array containing the values that must be used for scattering.
update_computation: Computation to be used for combining the existing values in
the input array and the updates during scatter.
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

REGISTER_OP("XlaAllReduce")
    .Input("input: T")
    .Input("group_assignment: int32")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32}")
    .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean'}")
    .Attr("mode: {'CrossReplica', 'CrossReplicaAndPartition'}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA AllReduce operator
documented at https://www.tensorflow.org/xla/operation_semantics#allreduce.

input: Array or a non-empty tuple of arrays to reduce across replicas.
group_assignment: Groups between which the reductions are performed.
reduce_op: Reduction computation.
mode: group mode.
  CrossReplica: group_assignment contains replica_id. Each group contains the
    replicas for the current partition.
  CrossReplicaAndPartition: group_assignment contains replica_id. Each group
    contains the replicas for all partitions.
)doc");

REGISTER_OP("XlaReduceScatter")
    .Input("input: T")
    .Input("group_assignment: int32")
    .Input("scatter_dimension: int32")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32}")
    .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean'}")
    .SetShapeFn(shape_inference::ReduceScatterShape)
    .Doc(R"doc(
Wraps the XLA ReduceScatter operator
documented at https://www.tensorflow.org/xla/operation_semantics#reducescatter.

input: Array or a non-empty tuple of arrays to reduce across replicas.
group_assignment: Groups between which the reductions are performed.
scatter_dimension: Dimension to scatter.
reduce_op: Reduction computation.
)doc");

Status OptimizationBarrierShape(shape_inference::InferenceContext* c) {
  for (int i = 0; i < c->num_inputs(); ++i) {
    c->set_output(i, c->input(i));
  }
  return OkStatus();
}

REGISTER_OP("XlaOptimizationBarrier")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(OptimizationBarrierShape)
    .Doc(R"doc(
Wraps the XLA OptimizationBarrier operator.

Documented at https://www.tensorflow.org/xla/operation_semantics#optimizationbarrier.

input: A Tuple of Arrays of any type.
)doc");

REGISTER_OP("XlaReducePrecision")
    .Input("operand: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("exponent_bits: int")
    .Attr("mantissa_bits: int")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA ReducePrecision operator
documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.

operand: array of floating-point type.
exponent_bits: number of exponent bits in lower-precision format
mantissa_bits: number of mantissa bits in lower-precision format
)doc");

REGISTER_OP("XlaCustomCall")
    .Input("args: T")
    .Output("output: dtype")
    .Attr("target_name: string")
    .Attr("backend_config: string")
    .Attr("T: list(type) >= 0")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return OkStatus();
    })
    .Doc(R"doc(
Wraps the XLA CustomCall operator
documented at https://www.tensorflow.org/xla/operation_semantics#customcall.

args: A list of `Tensor` with possibly different types.
target_name: Name of the function. A call instruction will be emitted which
targets this symbol name.
backend_config: String, used to encode serialized metadata to the backend.
dtype: Output tensor data type.
shape: Output tensor shape.
)doc");

REGISTER_OP("XlaCallModule")
    .Input("args: Tin")
    .Output("output: Tout")
    .Attr("module: string")
    .Attr("Sout: list(shape) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("Tin: list(type) >= 0")
    .Attr("dim_args_spec: list(string) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // For debugging.
      VLOG(3) << "XlaCallModule.shape_inference";
      std::vector<shape_inference::ShapeHandle> args_shapes;
      TF_RETURN_IF_ERROR(c->input("args", &args_shapes));
      for (int i = 0; i < args_shapes.size(); ++i) {
        VLOG(3) << "XlaCallModule.shape_inference args[" << i
                << "] : " << c->DebugString(args_shapes[i]);
      }
      std::vector<PartialTensorShape> shapes_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("Sout", &shapes_attr));
      for (int i = 0; i < shapes_attr.size(); ++i) {
        shape_inference::ShapeHandle s;
        TF_RETURN_IF_ERROR(
            c->MakeShapeFromPartialTensorShape(shapes_attr[i], &s));
        VLOG(3) << "XlaCallModule.shape_inference out[" << i
                << "] : " << c->DebugString(s);
        c->set_output(i, s);
      }
      return OkStatus();
    })
    .Doc(R"doc(
Temporary op for experimenting with jax2tf.

DO NOT USE THIS OP. It has no backwards compatibility guarantees. It is also
very likely to change. This op will be used only in jax2tf under an
experimental flag.

This is an experimental op to allow a smooth evolution of jax2tf towards
emitting and serializing MHLO directly from JAX. At the moment this op
carries a serialized MHLO module, therefore there are no backward-compatibility
guarantees, and it should not be used for serialization.
Eventually, the op will carry an MHLO object, which will have
backwards-compatibility guarantees.

The serialized module must return a tuple if and only if Sout is an empty
list or a list with more than one element. The length of Tout and Sout must
match. This op always returns a tuple of results, even if the module returns
a single result.

The handling of dynamic shapes is work-in-progress. At the moment, the
JAX lowering for dynamic shapes will prepend one dimension parameter to the
serialized module for each dimension whose value must be passed in.
The "args" correspond to the non-dimension arguments. During compilation
we compute the values of the dimension arguments based on the static shapes of
the "args". In order to do this, we encode for each dimension argument a
specification of how to compute its value, as a string, in the form
"<arg_idx>.<axis_idx>".
E.g., the specification "2.1" denotes the value args[2].shape[1].

args: A list of `Tensor` with possibly different types to be passed as arguments
to the HLO module.
module: A serialized computation, a text representation of mlir.Module.
Tout: List of output tensor data types.
Sout: List of output tensor shapes.
dim_args_spec: the specification for the dimension arguments, one for each
dimension argument. In absence of dynamic shapes this list is empty.
)doc");
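
// Illustrative example (hypothetical): dim_args_spec = ["0.0", "0.1"] means
// the serialized module takes two prepended dimension arguments whose values
// are computed at compilation time as args[0].shape[0] and args[0].shape[1].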

}  // namespace
}  // namespace tensorflow