• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/algorithm/container.h"
17 #include "absl/strings/match.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_split.h"
20 #include "tensorflow/compiler/xla/xla_data.pb.h"
21 #include "tensorflow/core/framework/common_shape_fns.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/shape_inference.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 
26 namespace tensorflow {
27 namespace {
28 
29 // Helper shape function for operators that return an output with the same rank
30 // as their first input.
UnchangedRank(shape_inference::InferenceContext * c)31 Status UnchangedRank(shape_inference::InferenceContext* c) {
32   if (c->RankKnown(c->input(0))) {
33     c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
34   } else {
35     c->set_output(0, c->input(0));
36   }
37   return Status::OK();
38 }
39 
// Registers the "XlaBroadcastHelper" op: inputs `lhs` and `rhs` of type T and
// `broadcast_dims` of type Tindices (int32 or int64); outputs `lhs_output` and
// `rhs_output` of type T. Shape inference is delegated to
// shape_inference::UnknownShape.
40 REGISTER_OP("XlaBroadcastHelper")
41     .Input("lhs: T")
42     .Input("rhs: T")
43     .Input("broadcast_dims: Tindices")
44     .Attr("T: numbertype")
45     .Attr("Tindices: {int32, int64}")
46     .Output("lhs_output: T")
47     .Output("rhs_output: T")
48     .SetShapeFn(shape_inference::UnknownShape)
49     .Doc(R"doc(
50 Helper operator for performing XLA-style broadcasts
51 
52 Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
53 whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
54 for binary operators.
55 
56 lhs: the LHS input tensor
57 rhs: the RHS input tensor
58 broadcast_dims: an XLA-style broadcast dimension specification
59 lhs_output: the broadcasted LHS tensor
60 rhs_output: the broadcasted RHS tensor
61 )doc");
62 
// Registers the "XlaSelfAdjointEig" op: one input `a` of type T, attributes
// `lower` (bool), `max_iter` (int) and `epsilon` (float), and outputs `w` and
// `v` of type T. Shape inference is delegated to
// shape_inference::UnknownShape.
// NOTE(review): the formula in the doc text below writes the eigenvalues as
// `e[..., i]` and the input as `tensor`, whereas the op names them `w` and
// `a` -- consider aligning the doc text with the registered names.
63 REGISTER_OP("XlaSelfAdjointEig")
64     .Input("a: T")
65     .Attr("lower: bool")
66     .Attr("max_iter: int")
67     .Attr("epsilon: float")
68     .Output("w: T")
69     .Output("v: T")
70     .SetShapeFn(shape_inference::UnknownShape)
71     .Attr("T: numbertype")
72     .Doc(R"doc(
73 Computes the eigen decomposition of a batch of self-adjoint matrices
74 (Note: Only real inputs are supported).
75 
76 Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
77 tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
78 i=0...N-1.
79 
80 a: the input tensor.
81 
82 lower: a boolean specifies whether the calculation is done with the lower
83   triangular part or the upper triangular part.
84 
85 max_iter: maximum number of sweep update, i.e., the whole lower triangular
86   part or upper triangular part based on parameter lower. Heuristically, it has
87   been argued that approximately logN sweeps are needed in practice (Ref: Golub &
88   van Loan "Matrix Computation").
89 
90 epsilon: the tolerance ratio.
91 
92 w: The eigenvalues in ascending order, each repeated according to its
93   multiplicity.
94 v: The column v[..., :, i] is the normalized eigenvector corresponding to the
95   eigenvalue w[..., i].
96 )doc");
97 
98 REGISTER_OP("XlaSvd")
99     .Input("a: T")
100     .Attr("max_iter: int")
101     .Attr("epsilon: float")
102     .Attr("precision_config: string")
103     .Output("s: T")
104     .Output("u: T")
105     .Output("v: T")
106     .SetShapeFn(shape_inference::UnknownShape)
107     .Attr("T: numbertype")
108     .Doc(R"doc(
109 Computes the eigen decomposition of a batch of self-adjoint matrices
110 (Note: Only real inputs are supported).
111 
112 Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
113 tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
114 
115 a: the input tensor.
116 
117 max_iter: maximum number of sweep update, i.e., the whole lower triangular
118   part or upper triangular part based on parameter lower. Heuristically, it has
119   been argued that approximately log(min (M, N)) sweeps are needed in practice
120   (Ref: Golub & van Loan "Matrix Computation").
121 
122 epsilon: the tolerance ratio.
123 
124 precision_config: a serialized xla::PrecisionConfig proto.
125 
126 s: Singular values. The values are sorted in reverse order of magnitude, so
127   s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
128 u: Left singular vectors.
129 v: Right singular vectors.
130 )doc");
131 
// Registers the "XlaConv" op: inputs `lhs` and `rhs` of type T plus five
// Tindices (int32 or int64) inputs (`window_strides`, `padding`,
// `lhs_dilation`, `rhs_dilation`, `feature_group_count`), string attributes
// `dimension_numbers` and `precision_config`, and one output of type T.
// Shape inference uses UnchangedRank, i.e. it is driven by input 0 (`lhs`).
132 REGISTER_OP("XlaConv")
133     .Input("lhs: T")
134     .Input("rhs: T")
135     .Input("window_strides: Tindices")
136     .Input("padding: Tindices")
137     .Input("lhs_dilation: Tindices")
138     .Input("rhs_dilation: Tindices")
139     .Input("feature_group_count: Tindices")
140     .Attr("T: numbertype")
141     .Attr("Tindices: {int32, int64}")
142     .Attr("dimension_numbers: string")
143     .Attr("precision_config: string")
144     .Output("output: T")
145     .SetShapeFn(UnchangedRank)
146     .Doc(R"doc(
147 Wraps the XLA ConvGeneralDilated operator, documented at
148  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
149 .
150 
151 lhs: the input tensor
152 rhs: the kernel tensor
153 window_strides: the inter-window strides
154 padding: the padding to apply at the start and end of each input dimensions
155 lhs_dilation: dilation to apply between input elements
156 rhs_dilation: dilation to apply between kernel elements
157 feature_group_count: number of feature groups for grouped convolution.
158 dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
159 precision_config: a serialized xla::PrecisionConfig proto.
160 )doc");
161 
162 REGISTER_OP("XlaDot")
163     .Input("lhs: T")
164     .Input("rhs: T")
165     .Attr("T: numbertype")
166     .Attr("dimension_numbers: string")
167     .Attr("precision_config: string")
168     .Output("output: T")
__anon00ae58f50202(shape_inference::InferenceContext* c) 169     .SetShapeFn([](shape_inference::InferenceContext* c) {
170       shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
171       shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
172       if (!c->FullyDefined(lhs_shape_handle) ||
173           !c->FullyDefined(rhs_shape_handle)) {
174         return shape_inference::UnknownShape(c);
175       }
176 
177       string dimension_numbers_string;
178       TF_RETURN_IF_ERROR(
179           c->GetAttr("dimension_numbers", &dimension_numbers_string));
180 
181       xla::DotDimensionNumbers dimension_numbers;
182       dimension_numbers.ParseFromString(dimension_numbers_string);
183 
184       // Check that number of contracting dimensions match.
185       if (dimension_numbers.lhs_contracting_dimensions_size() !=
186           dimension_numbers.rhs_contracting_dimensions_size())
187         return errors::InvalidArgument(
188             "Must specify the same number of contracting dimensions for lhs "
189             "and rhs. Got: ",
190             dimension_numbers.lhs_contracting_dimensions_size(), " and ",
191             dimension_numbers.rhs_contracting_dimensions_size());
192 
193       // Check that contracting dimension sizes match.
194       for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
195            ++i) {
196         const int64 lhs_contracting_dimension =
197             dimension_numbers.lhs_contracting_dimensions(i);
198         const int64 rhs_contracting_dimension =
199             dimension_numbers.rhs_contracting_dimensions(i);
200         shape_inference::DimensionOrConstant
201             lhs_contracting_dimension_or_constant(
202                 c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
203         shape_inference::DimensionOrConstant
204             rhs_contracting_dimension_or_constant(
205                 c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
206         const int64 lhs_contracting_dimension_size =
207             c->Value(lhs_contracting_dimension_or_constant);
208         const int64 rhs_contracting_dimension_size =
209             c->Value(rhs_contracting_dimension_or_constant);
210         if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
211           return errors::InvalidArgument(
212               "Contracting dimension sizes do not match. Got: ",
213               lhs_contracting_dimension_size, " and ",
214               rhs_contracting_dimension_size);
215         }
216       }
217 
218       // Check that number of batch dimensions match.
219       if (dimension_numbers.lhs_batch_dimensions_size() !=
220           dimension_numbers.rhs_batch_dimensions_size())
221         return errors::InvalidArgument(
222             "Must specify the same number of batch dimensions for lhs "
223             "and rhs. Got: ",
224             dimension_numbers.lhs_batch_dimensions_size(), " and ",
225             dimension_numbers.rhs_batch_dimensions_size());
226 
227       // Check that batch dimension sizes match.
228       for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size();
229            ++i) {
230         const int64 lhs_batch_dimension =
231             dimension_numbers.lhs_batch_dimensions(i);
232         const int64 rhs_batch_dimension =
233             dimension_numbers.rhs_batch_dimensions(i);
234         shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
235             c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
236         shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
237             c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
238         const int64 lhs_batch_dimension_size =
239             c->Value(lhs_batch_dimension_or_constant);
240         const int64 rhs_batch_dimension_size =
241             c->Value(rhs_batch_dimension_or_constant);
242         if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
243           return errors::InvalidArgument(
244               "Batch dimension sizes do not match. Got: ",
245               lhs_batch_dimension_size, " and ", rhs_batch_dimension_size);
246         }
247       }
248 
249       // The ranks of lhs and rhs are decremented by 1 respectively due to the
250       // contraction, and added for the rank of the result. When an input tensor
251       // is a scalar, its contribution to the rank of the result is 0. Generate
252       // the result dimensions in order, rhs dimensions followed by lhs
253       // dimensions except the contracted and batch dimensions.
254       std::vector<shape_inference::DimensionHandle> output_dims;
255       for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
256         output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
257       }
258       const int32 lhs_rank = c->Rank(lhs_shape_handle);
259       for (int64 i = 0; i < lhs_rank; ++i) {
260         if (absl::c_linear_search(
261                 dimension_numbers.lhs_contracting_dimensions(), i) ||
262             absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
263                                   i)) {
264           continue;
265         }
266         output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
267       }
268 
269       const int32 rhs_rank = c->Rank(rhs_shape_handle);
270       for (int64 i = 0; i < rhs_rank; ++i) {
271         if (absl::c_linear_search(
272                 dimension_numbers.rhs_contracting_dimensions(), i) ||
273             absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
274                                   i)) {
275           continue;
276         }
277         output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
278       }
279 
280       c->set_output(0, c->MakeShape(output_dims));
281       return Status::OK();
282     })
283     .Doc(R"doc(
284 Wraps the XLA DotGeneral operator, documented at
285  https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
286 .
287 
288 lhs: the LHS tensor
289 rhs: the RHS tensor
290 dimension_numbers: a serialized xla::DotDimensionNumbers proto.
291 precision_config: a serialized xla::PrecisionConfig proto.
292 )doc");
293 
294 REGISTER_OP("XlaDynamicSlice")
295     .Input("input: T")
296     .Input("start_indices: Tindices")
297     .Input("size_indices: Tindices")
298     .Output("output: T")
299     .Attr("T: type")
300     .Attr("Tindices: {int32, int64}")
301     .SetShapeFn(shape_inference::UnknownShape)
302     .Doc(R"doc(
303 Wraps the XLA DynamicSlice operator, documented at
304  https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
305 .
306 
307 DynamicSlice extracts a sub-array from the input array at dynamic
308 start_indices. The size of the slice in each dimension is passed in
309 size_indices, which specify the end point of exclusive slice intervals in each
310 dimension -- [start, start + size). The shape of start_indices must have rank 1,
311 with dimension size equal to the rank of operand.
312 
313 input: A `Tensor` of type T.
314 
315 start_indices: Rank 1 tensor of N integers containing the starting indices of
316   the slice for each dimension. Value must be greater than or equal to zero.
317 
318 start_indices: List of N integers containing the slice size for each
319   dimension. Each value must be strictly greater than zero, and start + size
320   must be less than or equal to the size of the dimension to avoid
321   implementation defined behavior.
322 )doc");
323 
// Registers the "XlaDynamicUpdateSlice" op: inputs `input` and `update` of
// type T, input `indices` of type Tindices (int32 or int64), and one output of
// type T. Shape inference is delegated to shape_inference::UnchangedShape.
324 REGISTER_OP("XlaDynamicUpdateSlice")
325     .Input("input: T")
326     .Input("update: T")
327     .Input("indices: Tindices")
328     .Output("output: T")
329     .Attr("T: type")
330     .Attr("Tindices: {int32, int64}")
331     .SetShapeFn(shape_inference::UnchangedShape)
332     .Doc(R"doc(
333 Wraps the XLA DynamicUpdateSlice operator, documented at
334  https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
335 .
336 
337 XlaDynamicUpdateSlice generates a result which is the value of the `input`
338 operand, with a slice update overwritten at `indices`. The shape of `update`
339 determines the shape of the sub-array of the result which is updated. The shape
340 of indices must be rank == 1, with dimension size equal to the rank of `input`.
341 
342 Handling of out-of-bounds slice indices is implementation-defined.
343 
344 input: A `Tensor` of type T.
345 indices: A vector of indices into `input`. Must have length equal to the rank of
346   `input`.
347 update: A `Tensor` of type T. Same rank as `input`.
348 output: A `Tensor` of type T.
349 )doc");
350 
351 // TODO(b/37549631) setting the If Op to always be stateful is too
352 // conservative.
// Registers the "XlaIf" op: input `cond` of type Tcond, a list input `inputs`
// with types Tin, a list output `output` with types Tout, and function
// attributes `then_branch` and `else_branch`. The op is marked stateful and
// shape inference is delegated to shape_inference::UnknownShape.
353 REGISTER_OP("XlaIf")
354     .Input("cond: Tcond")
355     .Input("inputs: Tin")
356     .Output("output: Tout")
357     .Attr("Tcond: type")
358     .Attr("then_branch: func")
359     .Attr("else_branch: func")
360     .Attr("Tin: list(type) >= 0")
361     .Attr("Tout: list(type) >= 0")
362     .SetIsStateful()
363     .SetShapeFn(shape_inference::UnknownShape)
364     .Doc(R"doc(
365 output = cond ? then_branch(inputs) : else_branch(inputs).
366 
367 cond: A boolean scalar.
368 inputs: A list of input tensors.
369 output: A list of tensors returned by either then_branch(inputs) or
370         else_branch(inputs). The input shapes of the then_branch and
371         else_branch must match.
372 then_branch: A function takes 'inputs' and returns a list of tensors,
373              whose types are the same as what else_branch returns.
374 else_branch: A function takes 'inputs' and returns a list of tensors.
375              whose types are the same as what then_branch returns.
376 )doc");
377 
// Registers the "XlaPad" op: inputs `input` and `padding_value` of type T,
// inputs `padding_low`, `padding_high` and `padding_interior` of type Tindices
// (int32 or int64), and one output of type T. Shape inference uses
// UnchangedRank, i.e. it is driven by input 0 (`input`).
378 REGISTER_OP("XlaPad")
379     .Input("input: T")
380     .Input("padding_value: T")
381     .Input("padding_low: Tindices")
382     .Input("padding_high: Tindices")
383     .Input("padding_interior: Tindices")
384     .Output("output: T")
385     .Attr("T: type")
386     .Attr("Tindices: {int32, int64}")
387     .SetShapeFn(UnchangedRank)
388     .Doc(R"doc(
389 Wraps the XLA Pad operator, documented at
390  https://www.tensorflow.org/performance/xla/operation_semantics#pad
391 .
392 
393 input: A `Tensor` of type T.
394 padding_value: A scalar `Tensor` of type T.
395 padding_low: the padding to apply at the start of each input dimensions
396 padding_high: the padding to apply at the end of each input dimension.
397 padding_interior: the padding to apply between each input element.
398 output: A `Tensor` of type T.
399 )doc");
400 
401 REGISTER_OP("XlaRecv")
402     .Output("tensor: dtype")
403     .Attr("dtype: type")
404     .Attr("tensor_name: string")
405     .Attr("shape: shape")
406     .SetIsStateful()
__anon00ae58f50302(shape_inference::InferenceContext* c) 407     .SetShapeFn([](shape_inference::InferenceContext* c) {
408       TensorShape shape_attr;
409       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
410       shape_inference::ShapeHandle s;
411       TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
412       c->set_output(0, s);
413       return Status::OK();
414     })
415     .Doc(R"doc(
416 Receives the named tensor from another XLA computation. Wraps the XLA Recv
417 operator documented at
418  https://www.tensorflow.org/performance/xla/operation_semantics#recv .
419 
420 tensor: The tensor to receive.
421 dtype: The type of the tensor.
422 tensor_name: A string key that identifies the channel.
423 shape: The shape of the tensor.
424 )doc");
425 
426 REGISTER_OP("XlaReduce")
427     .Input("input: T")
428     .Input("init_value: T")
429     .Attr("T: numbertype")
430     .Attr("dimensions_to_reduce: list(int)")
431     .Attr("reducer: func")
432     .Output("output: T")
__anon00ae58f50402(shape_inference::InferenceContext* c) 433     .SetShapeFn([](shape_inference::InferenceContext* c) {
434       if (c->RankKnown(c->input(0))) {
435         int rank = c->Rank(c->input(0));
436         std::vector<int64> dimensions_to_reduce;
437         TF_RETURN_IF_ERROR(
438             c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
439         std::set<int64> dims_set(dimensions_to_reduce.begin(),
440                                  dimensions_to_reduce.end());
441         auto dim_in_range = [rank](int64 dim) {
442           return dim >= 0 && dim < rank;
443         };
444         if (rank < dimensions_to_reduce.size() ||
445             dims_set.size() != dimensions_to_reduce.size() ||
446             !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
447           return errors::InvalidArgument(
448               "Invalid dimensions_to_reduce argument to XlaReduce");
449         }
450         c->set_output(
451             0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
452       } else {
453         c->set_output(0, c->input(0));
454       }
455       return Status::OK();
456     })
457     .Doc(R"doc(
458 Wraps the XLA Reduce operator, documented at
459  https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
460 
461 input: the input tensor
462 init_value: a scalar representing the initial value for the reduction
463 reducer: a reducer function to apply
464 dimensions_to_reduce: dimension numbers over which to reduce
465 )doc");
466 
// Registers the "XlaReduceWindow" op: inputs `input` and `init_value` of type
// T, five Tindices (int32 or int64) inputs (`window_dimensions`,
// `window_strides`, `base_dilations`, `window_dilations`, `padding`), a
// function attribute `computation`, and one output of type T. Shape inference
// uses UnchangedRank, i.e. it is driven by input 0 (`input`).
// NOTE(review): the doc text below has no entries for `base_dilations` and
// `window_dilations`.
467 REGISTER_OP("XlaReduceWindow")
468     .Input("input: T")
469     .Input("init_value: T")
470     .Input("window_dimensions: Tindices")
471     .Input("window_strides: Tindices")
472     .Input("base_dilations: Tindices")
473     .Input("window_dilations: Tindices")
474     .Input("padding: Tindices")
475     .Attr("T: numbertype")
476     .Attr("Tindices: {int32, int64}")
477     .Attr("computation: func")
478     .Output("output: T")
479     .SetShapeFn(UnchangedRank)
480     .Doc(R"doc(
481 Wraps the XLA ReduceWindow operator, documented at
482  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
483 
484 input: the input tensor
485 init_value: a scalar representing the initial value for the reduction
486 computation: a reducer function to apply
487 window_dimensions: the shape of the window
488 window_strides: the inter-window strides
489 padding: the padding to apply at the start and end of each input dimensions
490 )doc");
491 
// Registers the "XlaSelectAndScatter" op: inputs `operand`, `source` and
// `init_value` of type T, inputs `window_dimensions`, `window_strides` and
// `padding` of type Tindices (int32 or int64), function attributes `select`
// and `scatter`, and one output of type T. Shape inference uses UnchangedRank,
// i.e. it is driven by input 0 (`operand`).
492 REGISTER_OP("XlaSelectAndScatter")
493     .Input("operand: T")
494     .Input("window_dimensions: Tindices")
495     .Input("window_strides: Tindices")
496     .Input("padding: Tindices")
497     .Input("source: T")
498     .Input("init_value: T")
499     .Attr("T: numbertype")
500     .Attr("Tindices: {int32, int64}")
501     .Attr("select: func")
502     .Attr("scatter: func")
503     .Output("output: T")
504     .SetShapeFn(UnchangedRank)
505     .Doc(R"doc(
506 Wraps the XLA SelectAndScatter operator, documented at
507  https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
508 .
509 
510 operand: the input tensor
511 window_dimensions: the shape of the window
512 window_strides: the inter-window strides
513 padding: the padding to apply at the start and end of each input dimensions
514 source: a tensor of values to scatter
515 init_value: a scalar representing the initial value for the output tensor
516 select: a selection function to apply
517 scatter: a scatter function to apply
518 )doc");
519 
// Registers the stateful "XlaSend" op: one input `tensor` of type T, a string
// attribute `tensor_name`, and no outputs. Shape inference is delegated to
// shape_inference::UnknownShape.
520 REGISTER_OP("XlaSend")
521     .Input("tensor: T")
522     .Attr("T: type")
523     .Attr("tensor_name: string")
524     .SetIsStateful()
525     .SetShapeFn(shape_inference::UnknownShape)
526     .Doc(R"doc(
527 Sends the named tensor to another XLA computation. Wraps the XLA Send operator
528 documented at
529  https://www.tensorflow.org/performance/xla/operation_semantics#send .
530 
531 tensor: The tensor to send.
532 tensor_name: A string key that identifies the channel.
533 )doc");
534 
// Registers the "XlaSort" op: one input and one output, both of type T.
// Shape inference is delegated to shape_inference::UnchangedShape.
535 REGISTER_OP("XlaSort")
536     .Input("input: T")
537     .Output("output: T")
538     .Attr("T: type")
539     .SetShapeFn(shape_inference::UnchangedShape)
540     .Doc(R"doc(
541 Wraps the XLA Sort operator, documented at
542  https://www.tensorflow.org/performance/xla/operation_semantics#sort
543 .
544 
545 Sorts a tensor. Currently only sorts in ascending order are supported.
546 
547 input: A `Tensor` of type T.
548 output: A `Tensor` of type T.
549 )doc");
550 
551 REGISTER_OP("XlaKeyValueSort")
552     .Input("keys: K")
553     .Input("values: V")
554     .Output("sorted_keys: K")
555     .Output("sorted_values: V")
556     .Attr("K: realnumbertype")
557     .Attr("V: type")
__anon00ae58f50602(shape_inference::InferenceContext* c) 558     .SetShapeFn([](shape_inference::InferenceContext* c) {
559       c->set_output(0, c->input(0));
560       c->set_output(1, c->input(1));
561       return Status::OK();
562     })
563     .Doc(R"doc(
564 Wraps the XLA Sort operator, documented at
565  https://www.tensorflow.org/performance/xla/operation_semantics#sort
566 .
567 
568 Sorts a tensor. Currently only sorts in ascending order are supported.
569 
570 keys: A `Tensor` of type K.
571 values: A `Tensor` of type V.
572 sorted_keys: A `Tensor` of type K.
573 sorted_values: A `Tensor` of type V.
574 )doc");
575 
576 // TODO(b/37549631) setting the While Op to always be stateful is too
577 // conservative.
// Registers the stateful "XlaWhile" op: a list input `input` and a list output
// `output`, both with types T, and function attributes `cond` and `body`.
// Shape inference is delegated to shape_inference::UnknownShape.
578 REGISTER_OP("XlaWhile")
579     .Input("input: T")
580     .Output("output: T")
581     .Attr("T: list(type) >= 0")
582     .Attr("cond: func")
583     .Attr("body: func")
584     .SetIsStateful()
585     .SetShapeFn(shape_inference::UnknownShape)
586     .Doc(R"doc(
587 output = input; While (Cond(output)) { output = Body(output) }
588 
589 input: A list of input tensors whose types are T.
590 output: A list of output tensors whose types are T.
591 cond: A function takes 'input' and returns a tensor.  If the tensor is
592       a scalar of non-boolean, the scalar is converted to a boolean
593       according to the following rule: if the scalar is a numerical
594       value, non-zero means True and zero means False; if the scalar is
595       a string, non-empty means True and empty means False. If the
596       tensor is not a scalar, non-emptiness means True and False
597       otherwise.
598 body: A function that takes a list of tensors and returns another
599       list of tensors. Both lists have the same types as specified by T.
600 )doc");
601 
602 REGISTER_OP("XlaDequantize")
603     .Input("input: uint32")
604     .Output("output: bfloat16")
605     .Attr("min_range: float")
606     .Attr("max_range: float")
607     .Attr("mode: string")
608     .Attr("transpose_output: bool")
609     .SetIsStateful()
610     .SetShapeFn(shape_inference::UnknownShape)
611     .Doc(R"doc(
612 Takes the packed uint32 input and unpacks the input to uint8 to do
613 Dequantization on device.
614 
615 input: Input tensors whose types is uint32, shape is [d0, ..., dn].
616 output: Output tensors whose types is bloat16. If transpose_output is true,
617      output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
618      is false, output shape is [d0,..., dn * 4].
619 min_range: The minimum scalar value possibly produced for the input.
620 max_range: The maximum scalar value possibly produced for the input.
621 mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
622 transpose_output: Boolean to determine if output is transposed. transpose_output
623      is faster when input is large and rank of input is higher than 1.
624 )doc");
625 
626 REGISTER_OP("XlaEinsum")
627     .Input("a: T")
628     .Input("b: T")
629     .Output("product: T")
630     .Attr("equation: string")
631     .Attr("T: {complex64, bfloat16, float}")
__anon00ae58f50702(shape_inference::InferenceContext* context) 632     .SetShapeFn([](shape_inference::InferenceContext* context) {
633       string equation;
634       TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
635       // XlaEinsum supports only two-input einsum equations.
636       if (!absl::StrContains(equation, ",")) {
637         return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
638                                        equation);
639       }
640       // Use EinsumShape for the rest of the inference now that we know we must
641       // have a two-input einsum.
642       return shape_inference::EinsumShape(context);
643     })
644     .Doc(R"doc(
645 An op which supports basic einsum op with 2 inputs and 1 output.
646 
647 This op has better TPU performance since it doesn't have explicitly reshape and
648 transpose operations as tf.einsum does.
649 )doc");
650 
// Registers the "XlaSharding" op: one input and one output, both of type T.
// Shape inference is delegated to shape_inference::UnchangedShape.
// NOTE(review): the doc text mentions a "sharding attribute", but no such
// attribute is declared in this registration -- presumably it is attached
// elsewhere; confirm.
651 REGISTER_OP("XlaSharding")
652     .Input("input: T")
653     .Output("output: T")
654     .Attr("T: type")
655     .SetShapeFn(shape_inference::UnchangedShape)
656     .Doc(R"doc(
657 An op which shards the input based on the given sharding attribute.
658 )doc");
659 
660 REGISTER_OP("XlaReplicaId")
661     .Output("id: int32")
__anon00ae58f50802(shape_inference::InferenceContext* context) 662     .SetShapeFn([](shape_inference::InferenceContext* context) {
663       context->set_output(0, context->MakeShape({}));
664       return Status::OK();
665     })
666     .Doc("Replica ID.");
667 
668 REGISTER_OP("XlaGather")
669     .Input("operand: T")
670     .Input("start_indices: Tindices")
671     .Input("slice_sizes: Tindices")
672     .Attr("dimension_numbers: string")
673     .Attr("indices_are_sorted: bool")
674     .Attr("T: numbertype")
675     .Attr("Tindices: {int32, int64}")
676     .Output("output: T")
677     .SetShapeFn(UnchangedRank)
678     .Doc(R"doc(
679 Wraps the XLA Gather operator documented at
680   https://www.tensorflow.org/xla/operation_semantics#gather
681 operand: The array we're gathering from.
682 start_indices: Array containing the starting indices of the slices we gather.
683 dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
684 slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
685 indices_are_sorted: Boolean indicating if the indices are sorted.
686 )doc");
687 
688 REGISTER_OP("XlaScatter")
689     .Input("operand: T")
690     .Input("scatter_indices: Tindices")
691     .Input("updates: T")
692     .Attr("update_computation: func")
693     .Attr("dimension_numbers: string")
694     .Attr("indices_are_sorted: bool")
695     .Attr("T: numbertype")
696     .Attr("Tindices: {int32, int64}")
697     .Output("output: T")
698     .SetShapeFn(UnchangedRank)
699     .Doc(R"doc(
700 Wraps the XLA Scatter operator documented at
701   https://www.tensorflow.org/xla/operation_semantics#scatter.
702 
703 operand: Array to be scattered into.
704 scatter_indices: Array containing the starting indices of the slices that must
705   be scattered to.
706 updates: Array containing the values that must be used for scattering.
707 update_computation: Computation to be used for combining the existing values in
708   the input array and the updates during scatter.
709 dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
710 indices_are_sorted: Boolean indicating if the indices are sorted.
711 )doc");
712 
713 }  // namespace
714 }  // namespace tensorflow
715