/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation definition file for TensorFlow.
//
// This file contains TensorFlow ops whose definitions are amended to fix
// issues or provide more information. In this file you have full control
// of the op definition; all changes will be retained with subsequent
// refreshes.
//
// This file includes another file, `tf_generated_ops.td`, which contains
// all ops whose definitions are generated from the TensorFlow codebase.
// Changes made there are not respected.

#ifndef TF_OPS
#define TF_OPS

include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"

class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
  let results = (outs
    TF_VariantTensor:$handle
  );

  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;

  let verifier = [{
    // This is required to populate derived attributes during export in a
    // meaningful way. Otherwise, the element_type() query during export to
    // GraphDef would result in an out-of-bounds access/assert.
    if (handle_dtype().getSubtypes().size() != 1) {
      return emitOpError(
          "must have exactly one subtype in the result variant type");
    }

    return Verify(*this);
  }];

  DerivedTypeAttr element_dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(element_type());">;

  let extraClassDeclaration = [{
    // Returns the type of the TensorList element produced by this op.
    TensorType element_type() { return handle_dtype().getSubtypes()[0]; }

    // Returns the data type of the result handle. The returned type contains
    // the type of the TensorList element as a subtype.
    VariantType handle_dtype() {
      return getElementTypeOrSelf(handle().getType()).cast<TF::VariantType>();
    }
  }];
}

def TF_CaseOp : TF_Op<"Case", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0](input);
        break;
      case 1:
        output = branches[1](input);
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1](input);
        break;
    }
    ```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,
    Variadic<TF_Tensor>:$input,

    Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,

    // Used to map the StatelessCase and Case ops defined in TensorFlow to a
    // common op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasCanonicalizer = 1;

  let verifier = [{
    return Verify(*this);
  }];

  let extraClassDeclaration = [{
    int num_branches() { return branches().size(); }

    // Gets the function corresponding to branch `index`.
    FuncOp branch_function(int index) {
      auto flat_sym_ref = branches()[index].cast<FlatSymbolRefAttr>();
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, flat_sym_ref);
    }

    // Gets all branch functions.
    void get_branch_functions(SmallVectorImpl<FuncOp> &functions) {
      functions.reserve(num_branches());
      for (int idx : llvm::seq<int>(0, num_branches()))
        functions.push_back(branch_function(idx));
    }
  }];
}
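
// Example (illustrative sketch; the branch function names @branch0/@branch1
// and all operand/result types are assumed):
//
//   %output = "tf.Case"(%branch_index, %arg)
//       {branches = [@branch0, @branch1], is_stateless = false}
//       : (tensor<i32>, tensor<f32>) -> tensor<f32>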

def TF_CaseRegionOp : TF_Op<"CaseRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0](input);
        break;
      case 1:
        output = branches[1](input);
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1](input);
        break;
    }
    ```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,

    // Used to map the StatelessCase and Case ops defined in TensorFlow to a
    // common op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let verifier = [{
    return Verify(*this);
  }];

  let hasCanonicalizer = 1;
}
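
// Example (illustrative sketch; %a, %b and all types are assumed). Each branch
// is a region terminated by tf.Yield; the regions take no arguments and
// capture values from the enclosing scope implicitly:
//
//   %output = "tf.CaseRegion"(%branch_index) ({
//     "tf.Yield"(%a) : (tensor<f32>) -> ()
//   }, {
//     "tf.Yield"(%b) : (tensor<f32>) -> ()
//   }) {is_stateless = true} : (tensor<i32>) -> tensor<f32>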

// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
// its type encoding the tensor's shape and data type.
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
  let summary = "Constant tensor op";

  let arguments = (ins
    ElementsAttr:$value
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;

  let builders = [
    OpBuilder<(ins "Attribute":$value)>,
    OpBuilder<(ins "Type":$type, "Attribute":$value)>,
  ];

  let hasFolder = 1;

  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return BroadcastCompatible(l, r);
    }
  }];
}
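
// Example (illustrative sketch; the attribute value and types are assumed):
//
//   %cst = "tf.Const"() {value = dense<[1.0, 2.0]> : tensor<2xf32>}
//       : () -> tensor<2xf32>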

def TF_XlaAllReduceOp : TF_Op<"XlaAllReduce", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
  let summary = "An Op to reduce inputs across replicated TPU instances.";

  let arguments = (ins
    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Int32, TF_Uint32]>:$input,
    TF_Int32Tensor:$group_assignment,
    TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add", "Mean"]>:$reduce_op
  );

  let results = (outs
    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Int32, TF_Uint32]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
  let summary = "Creates and returns an empty tensor list.";

  let description = [{
All list elements must be tensors of dtype element_dtype and shape compatible
with element_shape.

handle: an empty tensor list.
element_dtype: the type of elements in the list.
element_shape: a shape compatible with that of elements in the list.
  }];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$max_num_elements
  );
}

def TF_IfOp : TF_Op<"If", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = "output = cond ? then_branch(input) : else_branch(input)";

  let description = [{
output = cond ? then_branch(input) : else_branch(input)

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
input: A list of input tensors.
then_branch: A function that takes 'inputs' and returns a list of
    tensors, whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of
    tensors, whose types are the same as what then_branch returns.
  }];

  let arguments = (ins
    TF_Tensor:$cond,
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$then_branch,
    FlatSymbolRefAttr:$else_branch,

    // Used to map the StatelessIf and If ops defined in TensorFlow to a
    // common op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // Get the then branch function.
    FuncOp then_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, then_branch());
    }

    // Get the else branch function.
    FuncOp else_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, else_branch());
    }
  }];
}
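
// Example (illustrative sketch; the branch function names @then_fn/@else_fn
// and all types are assumed):
//
//   %output = "tf.If"(%cond, %arg)
//       {then_branch = @then_fn, else_branch = @else_fn, is_stateless = false}
//       : (tensor<i1>, tensor<f32>) -> tensor<f32>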

def TF_YieldOp : TF_Op<"Yield",
      [NoSideEffect, ReturnLike, Terminator,
       ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
  let summary = "Yield operation";

  let description = [{
    The "yield" operation represents a return operation within the condition
    and body regions of structured control flow ops (e.g., if and while). The
    operation takes a variable number of operands and produces no results. The
    number and types of inputs must match the signature of the operation that
    contains the region.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);
}

def TF_IfRegionOp : TF_Op<"IfRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = "output = cond ? then_branch output : else_branch output";

  let description = [{
"output = cond ? then_branch output : else_branch output"

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
then_branch: A region that computes the outputs of the op if cond = true.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors are the same as those of the else_branch.
else_branch: A region that computes the outputs of the op if cond = false.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors are the same as those of the then_branch.
  }];

  let arguments = (ins
    0DTensorOf<[I1]>:$cond,

    // Used to map the StatelessIf and If ops defined in TensorFlow to a
    // common op.
    BoolAttr:$is_stateless,
    // Used to maintain the function names when round-tripping
    // between functional and regional control flow. This can be removed if
    // the runtime does not require globally unique then/else branch function
    // names.
    OptionalAttr<StrAttr>:$_then_func_name,
    OptionalAttr<StrAttr>:$_else_func_name
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch);

  let verifier = [{
    return Verify(*this);
  }];

  let builders = [
    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
      "llvm::ArrayRef<::mlir::NamedAttribute>":$attributes,
      "unsigned":$numRegions),
    [{
      assert(numRegions == 2u && "mismatched number of regions");
      build($_builder, $_state, resultTypes, operands, attributes);
    }]>];

  let hasCanonicalizer = 1;
}
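
// Example (illustrative sketch; %cond, %then_value, %else_value and all types
// are assumed). Both regions are terminated by tf.Yield and capture values
// from the enclosing scope implicitly:
//
//   %output = "tf.IfRegion"(%cond) ({
//     "tf.Yield"(%then_value) : (tensor<f32>) -> ()
//   }, {
//     "tf.Yield"(%else_value) : (tensor<f32>) -> ()
//   }) {is_stateless = true} : (tensor<i1>) -> tensor<f32>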

def TF_LegacyCallOp : TF_Op<"LegacyCall",
                            [CallOpInterface, NoSideEffect]> {
  let summary =
    "returns `f(inputs)`, where `f` is a function.";

  let description = [{
    The LegacyCall operation represents a direct call to a function that is
    within the same symbol scope as the call and is mapped to a GraphDef node
    with the function name as the op name. Unlike a PartitionedCall which
    represents asynchronously executing a function across multiple devices, a
    LegacyCall ignores the specifications for ops in the attached function and
    instead executes it on the device assigned to this op.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    DefaultValuedAttr<BoolAttr, "false">:$_disable_call_shape_inference
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the callee function of this operation.
    FuncOp func() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, f());
    }
  }];
}
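
// Example (illustrative sketch; the callee name @my_func and the types are
// assumed):
//
//   %output = "tf.LegacyCall"(%arg)
//       {f = @my_func, _disable_call_shape_inference = false}
//       : (tensor<f32>) -> tensor<f32>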

def TF_ParseExampleOp : TF_Op<"ParseExample",
                               [NoSideEffect,
                                AttrSizedResultSegments,
                                AttrSizedOperandSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    Variadic<TF_StrTensor>:$sparse_keys,
    Variadic<TF_StrTensor>:$dense_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    TF_ShapeAttrArray:$dense_shapes,
    I32ElementsAttr:$result_segment_sizes,
    I32ElementsAttr:$operand_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                           // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,  // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values    // len(Tdense)
  );

  TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
  TF_DerivedOperandSizeAttr Ndense = TF_DerivedOperandSizeAttr<3>;
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<4>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;

  let verifier = ?;
}

def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
                                [NoSideEffect,
                                 AttrSizedResultSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    TF_StrTensor:$sparse_keys,
    TF_StrTensor:$dense_keys,
    TF_StrTensor:$ragged_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    Confined<I64Attr, [IntMinValue<0>]>:$num_sparse,
    TF_ShapeAttrArray:$dense_shapes,
    I32ElementsAttr:$result_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                           // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,  // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values,   // len(Tdense)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$ragged_values,  // len(ragged_value_types)
                                                                        //   = len(ragged_split_types)
    Variadic<TensorOf<[TF_Int32, TF_Int64]>>:$ragged_row_splits         // len(ragged_split_types)
                                                                        //   = len(ragged_value_types)
  );

  // The Verify(ParseExampleV2Op) function validates that the lengths and types
  // of these attrs are compatible.
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<5>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;
  TF_DerivedResultTypeListAttr ragged_value_types =
    TF_DerivedResultTypeListAttr<4>;
  TF_DerivedResultTypeListAttr ragged_split_types =
    TF_DerivedResultTypeListAttr<5>;

  let verifier = [{
    return Verify(*this);
  }];
}

def TF_PlaceholderOp : TF_Op<"Placeholder", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
Inserts a placeholder for a tensor that will always be fed.
  }];

  let arguments = (ins
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}

def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
    A placeholder op that passes through input when its output is not fed.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
  DerivedAttr shape = TF_DerivedResultShapeAttr;
}

def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
                                         [CallOpInterface]> {
  let summary =
    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";

  let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    StrAttr:$config,
    StrAttr:$config_proto,
    StrAttr:$executor_type
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the callee function of this operation.
    FuncOp func() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, f());
    }
  }];

  let verifier = [{ return VerifyPartitionedCall(*this); }];
}

def TF_WhileOp : TF_Op<"While", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
output = input; While (Cond(output)) { output = Body(output) }
  }];

  let description = [{
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor.  If the tensor is
    a scalar of non-boolean type, the scalar is converted to a boolean
    according to the following rule: if the scalar is a numerical
    value, non-zero means True and zero means False; if the scalar is
    a string, non-empty means True and empty means False. If the
    tensor is not a scalar, non-emptiness means True and emptiness
    means False.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$cond,
    FlatSymbolRefAttr:$body,
    DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,

    // Used to map the StatelessWhile and While ops defined in TensorFlow to a
    // common op.
    BoolAttr:$is_stateless,

    // In TensorFlow, While has a special behavior where if the `output_shapes`
    // attribute is not empty, those shapes are used in its shape function
    // as result shapes instead of propagating operand shapes as result shapes.
    // This allows for different result shapes from operand shapes. While these
    // shapes are imported and set as a part of the result type, there is no
    // indicator differentiating between having no output shapes compared to
    // having all unranked shapes. Thus this attribute is set to determine
    // which shape function behavior to use for this op, specifically
    // propagating operand shapes as result shapes when this attribute is not
    // set, or preserving result shapes as is when this attribute is set.
    UnitAttr:$shape_invariant
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let extraClassDeclaration = [{
    // Get the condition function.
    FuncOp cond_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, cond());
    }

    // Get the body function.
    FuncOp body_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, body());
    }
  }];
}
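
// Example (illustrative sketch; the function names @while_cond/@while_body and
// the types are assumed):
//
//   %result = "tf.While"(%init)
//       {cond = @while_cond, body = @while_body, is_stateless = false,
//        parallel_iterations = 10 : i64}
//       : (tensor<i32>) -> tensor<i32>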

def TF_WhileRegionOp : TF_Op<"WhileRegion",
      [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
       SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "while operation";
  let description = [{
  The tf.WhileRegion op represents a while loop using 2 regions and a set of
  iteration variables. The iteration variables maintained by this Op have the
  same types as the inputs. The Op executes a while loop described by the
  following pseudo code:

  ```
     func WhileRegionOp(inputs) {
       iteration_vars = inputs;
       while (cond(iteration_vars)) {
           iteration_vars = body(iteration_vars);
       }
       return iteration_vars;
     }
  ```

  `cond` is the condition region and `body` is the body region. Both these
  regions accept the current value of the iteration variables as inputs. The
  condition region returns a tensor<i1> which, if false, will exit the loop.
  The body region computes new values of the iteration variables. The iteration
  variables are initialized to the Op input, and the results of the
  tf.WhileRegion op are the final values of the iteration variables.

  This implies that the operand and result types for tf.WhileRegion should be
  the same. Note that the condition and body regions can implicitly capture
  loop invariant values directly. In canonical form, iteration variables that
  pass through the loop body unmodified are converted to implicitly captured
  references to their values outside the loop.
  }];

  let arguments = (ins
    Variadic<AnyTensor>:$input,

    DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,

    // Used to map the StatelessWhile and While ops defined in TensorFlow to a
    // common op.
    BoolAttr:$is_stateless,

    // In TensorFlow, While has a special behavior where if the `output_shapes`
    // attribute is not empty, those shapes are used in its shape function
    // as result shapes instead of propagating operand shapes as result shapes.
    // This allows for different result shapes from operand shapes. While these
    // shapes are imported and set as a part of the result type, there is no
    // indicator differentiating between having no output shapes compared to
    // having all unranked shapes. Thus this attribute is set to determine
    // which shape function behavior to use for this op, specifically
    // propagating operand shapes as result shapes when this attribute is not
    // set, or preserving result shapes as is when this attribute is set.
    UnitAttr:$shape_invariant
  );
  let results = (outs Variadic<AnyTensor>:$output);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let verifier = [{ return Verify(*this); }];

  let hasCanonicalizer = 1;
}
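
// Example (illustrative sketch; %init, %limit, %one and the types are assumed;
// %limit and %one are loop-invariant values captured implicitly from the
// enclosing scope, as described above):
//
//   %result = "tf.WhileRegion"(%init) ({
//   ^bb0(%carry: tensor<i32>):
//     %continue = "tf.Less"(%carry, %limit)
//         : (tensor<i32>, tensor<i32>) -> tensor<i1>
//     "tf.Yield"(%continue) : (tensor<i1>) -> ()
//   }, {
//   ^bb0(%carry: tensor<i32>):
//     %next = "tf.AddV2"(%carry, %one)
//         : (tensor<i32>, tensor<i32>) -> tensor<i32>
//     "tf.Yield"(%next) : (tensor<i32>) -> ()
//   }) {is_stateless = true} : (tensor<i32>) -> tensor<i32>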

def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
  let summary = "List of the given size with empty elements.";

  let description = [{
element_shape: the shape of the future elements of the list
num_elements: the number of elements to reserve
handle: the output list
element_dtype: the desired type of elements in the list.
  }];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$num_elements
  );
}

def TF_VarHandleOp : TF_Op<"VarHandleOp", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> {
  let summary = "Creates a handle to a Variable resource from its name.";

  let description = [{
container: the container this variable is placed in.
shared_name: the name by which this variable is referred to.
dtype and shape: attributes representing the data type and shape held in the
  variable.

Example:
    resource_variable_ops.var_handle_op(
          dtype=dtypes.int32, shape=[8, 16], container="foo", shared_name="bar")
  returns a handle for a variable with name "bar" in container "foo", and the
  variable holds a tensor of shape [8, 16] and dtype int32.
  }];

  let arguments = (ins
    DefaultValuedAttr<StrAttr, "">:$container,
    DefaultValuedAttr<StrAttr, "">:$shared_name
  );

  let results = (outs
    Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
  );

  let verifier = [{
    // VarHandleOp requires the resource handle to supply a single subtype from
    // which to derive the dtype and shape attributes.
    if (resource_type().getSubtypes().size() != 1) {
      return emitOpError(
          "must have exactly one subtype in the result resource type");
    }

    return success();
  }];

  DerivedTypeAttr dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(resource_subtype());">;
  DerivedAttr shape = DerivedAttr<
      "ShapedType",
      "return resource_subtype().cast<ShapedType>();",
      [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;

  let extraClassDeclaration = [{
    TensorType resource_subtype() { return resource_type().getSubtypes()[0]; }

    ResourceType resource_type() {
      return getElementTypeOrSelf(resource()).cast<TF::ResourceType>();
    }
  }];
}
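
// Example (illustrative sketch; the resource subtype syntax is assumed). The
// result's resource type must carry exactly one subtype, from which the
// derived dtype and shape attributes are computed, as the verifier above
// checks:
//
//   %handle = "tf.VarHandleOp"() {container = "foo", shared_name = "bar"}
//       : () -> tensor<!tf.resource<tensor<8x16xi32>>>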

def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect, TF_NoConstantFold]> {
  let summary = [{
An op which shards the input based on the given sharding attribute.
  }];

  let arguments = (ins
    TF_Tensor:$input,

    DefaultValuedAttr<StrAttr, "">:$sharding,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
  let summary = "Fetches multiple values from infeed as an XLA tuple.";

  let arguments = (ins
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    Variadic<TF_Tensor>:$outputs
  );

  TF_DerivedResultShapeListAttr shapes = TF_DerivedResultShapeListAttr<0>;
  TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}

// TODO(b/177675373): Make dtypes and shapes derived attributes,
// use more general solution.
def TF_InfeedEnqueueTupleOp : TF_Op<"InfeedEnqueueTuple", []> {
  let summary = [{
Feeds multiple Tensor values into the computation as an XLA tuple.
  }];

  let arguments = (ins
    Arg<Variadic<TF_Tensor>, [{A list of tensors that will be provided using the infeed mechanism.}]>:$inputs,

    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$dtypes,
    TF_ShapeAttrArray:$shapes,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$layouts,
    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal
  );

  let results = (outs);
}

// This op is manually defined because the attribute name `template` (which is
// a C++ keyword) is changed to `strtemplate`.
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
  let summary = "Formats a string template using a list of tensors.";

  let description = [{
Formats a string template using a list of tensors, pretty-printing tensor summaries.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    DefaultValuedAttr<StrAttr, "%s">:$strtemplate,
    DefaultValuedAttr<StrAttr, "%s">:$placeholder,
    DefaultValuedAttr<I64Attr, "3">:$summarize
  );

  let results = (outs
    TF_StrTensor:$output
  );

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}

//===----------------------------------------------------------------------===//
// tf.data ops
//===----------------------------------------------------------------------===//

def TF_ReduceDatasetOp : TF_Op<"ReduceDataset", [SameVariadicOperandSize]> {
  let summary = [{
    Reduces the input dataset to a singleton using a reduce function.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    Variadic<TF_Tensor>:$initial_state,
    Variadic<TF_Tensor>:$other_arguments,

    SymbolRefAttr:$f,
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$Tstate,
    Confined<TypeArrayAttr, [ArrayMinCount<0>]>:$Targuments,
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism
  );

  let results = (outs
    Variadic<TF_Tensor>:$components
  );
}

// Manually defined to restrict result type to `I1Tensor`.
def TF_ToBoolOp : TF_Op<"ToBool", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
  let summary = "Converts a tensor to a scalar predicate.";

  let description = [{
Converts a tensor to a scalar predicate with the following rules:

- For 0D tensors, truthiness is determined by comparing against a "zero"
  value. For numerical types it is the obvious zero. For strings it is the
  empty string.

- For >0D tensors, truthiness is determined by looking at the number of
  elements. If the tensor has zero elements, then the result is false.
  Otherwise the result is true.

This matches the behavior of If and While for determining if a tensor counts
as true/false for a branch condition.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    I1Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // InferTypeOpInterface:
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return ArraysAreCastCompatible(l, r);
    }
  }];
}
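
// Example (illustrative sketch; the input type is assumed):
//
//   %pred = "tf.ToBool"(%values) : (tensor<3xf32>) -> tensor<i1>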

def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i0e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.

This function is faster and numerically more stable than `bessel_i0(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i1e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 1 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.

This function is faster and numerically more stable than `bessel_i1(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
  let summary = "Calls a function placed on a specified TPU device.";

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    TF_Int32Tensor:$device_ordinal,

    SymbolRefAttr:$f,
    DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the callee function of this operation.
    FuncOp func() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, f());
    }
  }];

  let verifier = [{ return VerifyPartitionedCall(*this); }];
}

def TF_StatefulUniformFullIntOp : TF_Op<"StatefulUniformFullInt", []> {
  let summary = "Outputs random integers from a uniform distribution.";

  let description = [{
The generated values are uniform integers covering the whole range of `dtype`.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape
  );

  let results = (outs
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
  );

  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}

// TODO(lyandy): Investigate supported dtypes (`minval`, `maxval`, `output`) for
// `tf.StatefulUniformInt`. tf2xla kernels support i32, i64, ui32, and ui64
// while TensorFlow CPU/GPU kernels only support i32 and i64.
def TF_StatefulUniformIntOp : TF_Op<"StatefulUniformInt", []> {
  let summary = "Outputs random integers from a uniform distribution.";

  let description = [{
The generated values are uniform integers in the range `[minval, maxval)`.
The lower bound `minval` is included in the range, while the upper bound
`maxval` is excluded.

The random integers are slightly biased unless `maxval - minval` is an exact
power of two.  The bias is small for values of `maxval - minval` significantly
smaller than the range of the output (either `2^32` or `2^64`).
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape,
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval,
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval
  );

  let results = (outs
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
  );

  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<3>;
}

def TF_CloseSummaryWriterOp : TF_Op<"CloseSummaryWriter", []> {
  let summary = "Flushes and closes the summary writer.";

  let description = [{
Also removes it from the resource manager. To reopen, use another
CreateSummaryFileWriter op.

writer: A handle to the summary writer resource.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryFree]>:$writer
  );

  let results = (outs);
}

// TODO(b/168035831): Model db_uri read/write.
def TF_CreateSummaryDbWriterOp : TF_Op<"CreateSummaryDbWriter", []> {
  let summary = "Creates a summary database writer accessible by the given resource handle.";

  let description = [{
This can be used to write tensors from the execution graph directly
to a database. Only SQLite is supported right now. This function
will create the schema if it doesn't exist. Entries in the Users,
Experiments, and Runs tables will be created automatically if they
don't already exist.

writer: Handle to SummaryWriter resource to overwrite.
db_uri: For example "file:/tmp/foo.sqlite".
experiment_name: Can't contain ASCII control characters or <>. Case
  sensitive. If empty, then the Run will not be associated with any
  Experiment.
run_name: Can't contain ASCII control characters or <>. Case sensitive.
  If empty, then each Tag will not be associated with any Run.
user_name: Must be valid as both a DNS label and Linux username. If
  empty, then the Experiment will not be associated with any User.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$db_uri,
    TF_StrTensor:$experiment_name,
    TF_StrTensor:$run_name,
    TF_StrTensor:$user_name
  );

  let results = (outs);
}

// TODO(b/168035831): Model logdir read/write.
def TF_CreateSummaryFileWriterOp : TF_Op<"CreateSummaryFileWriter", []> {
  let summary = "Creates a summary file writer accessible by the given resource handle.";

  let description = [{
writer: A handle to the summary writer resource.
logdir: Directory where the event file will be written.
max_queue: Size of the queue of pending events and summaries.
flush_millis: How often, in milliseconds, to flush the pending events and
  summaries to disk.
filename_suffix: Every event file's name is suffixed with this suffix.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$logdir,
    TF_Int32Tensor:$max_queue,
    TF_Int32Tensor:$flush_millis,
    TF_StrTensor:$filename_suffix
  );

  let results = (outs);
}

def TF_FlushSummaryWriterOp : TF_Op<"FlushSummaryWriter", []> {
  let summary = "Flushes the writer's unwritten events.";

  let description = [{
writer: A handle to the summary writer resource.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer
  );

  let results = (outs);
}

def TF_ImportEventOp : TF_Op<"ImportEvent", []> {
  let summary = "Outputs a `tf.Event` protocol buffer.";

  let description = [{
When CreateSummaryDbWriter is being used, this op can be useful for
importing data from event logs.

writer: A handle to a summary writer.
event: A string containing a binary-encoded tf.Event proto.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$event
  );

  let results = (outs);
}

def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> {
  let summary = "Returns a handle to be used to access a summary writer.";

  let description = [{
The summary writer is an in-graph resource which can be used by ops to write
summaries to event files.

writer: the summary writer resource. Scalar handle.
  }];

  let arguments = (ins
    StrAttr:$shared_name,
    StrAttr:$container
  );

  let results = (outs
    Res<TF_ResourceTensor, "", [TF_SummaryAlloc]>:$writer
  );
}

def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with audio.";

  let description = [{
The summary has up to `max_outputs` summary values containing audio. The
audio is built from `tensor` which must be 3-D with shape `[batch_size,
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.

The `tag` argument is a scalar `Tensor` of type `string`.  It is used to
build the `tag` of the summary values:

*  If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
*  If `max_outputs` is greater than 1, the summary value tags are
   generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 2-D of shape `[batch_size, frames]`.
sample_rate: The sample rate of the signal in hertz.
max_outputs: Max number of batch elements to generate audio for.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_Float32Tensor:$tensor,
    TF_Float32Tensor:$sample_rate,

    Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_outputs
  );

  let results = (outs);
}

def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> {
  let summary = "Writes a `GraphDef` protocol buffer to a `SummaryWriter`.";

  let description = [{
writer: Handle of `SummaryWriter`.
step: The step to write the summary for.
tensor: A scalar string of the serialized tf.GraphDef proto.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tensor
  );

  let results = (outs);
}

def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> {
  let summary = "Writes a histogram summary.";

  let description = [{
The generated
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
has one summary value containing a histogram for `values`.

This op reports an `InvalidArgument` error if any value is not finite.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar.  Tag to use for the `Summary.Value`.
values: Any shape. Values to use to build the histogram.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_IntOrFpTensor:$values
  );

  let results = (outs);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}

def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with images.";

  let description = [{
The summary has up to `max_images` summary values containing images. The
images are built from `tensor` which must be 4-D with shape `[batch_size,
height, width, channels]` and where `channels` can be:

*  1: `tensor` is interpreted as Grayscale.
*  3: `tensor` is interpreted as RGB.
*  4: `tensor` is interpreted as RGBA.

The images have the same number of channels as the input tensor. For float
input, the values are normalized one image at a time to fit in the range
`[0, 255]`.  `uint8` values are unchanged.  The op uses two different
normalization algorithms:

*  If the input values are all positive, they are rescaled so the largest one
   is 255.

*  If any input value is negative, the values are shifted so input value 0.0
   is at 127.  They are then rescaled so that either the smallest value is 0,
   or the largest one is 255.

The `tag` argument is a scalar `Tensor` of type `string`.  It is used to
build the `tag` of the summary values:

*  If `max_images` is 1, the summary value tag is '*tag*/image'.
*  If `max_images` is greater than 1, the summary value tags are
   generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.

The `bad_color` argument is the color to use in the generated images for
non-finite input values.  It is a `uint8` 1-D tensor of length `channels`.
Each element must be in the range `[0, 255]` (It represents the value of a
pixel in the output image).  Non-finite values in the input tensor are
replaced by this tensor in the output image.  The default value is the color
red.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 4-D of shape `[batch_size, height, width, channels]` where
  `channels` is 1, 3, or 4.
max_images: Max number of batch elements to generate images for.
bad_color: Color to use for pixels with non-finite values.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor,
    TF_Uint8Tensor:$bad_color,

    Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_images
  );

  let results = (outs);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}

def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with serialized string `Summary` protocol buffers.";

  let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor holding one or more serialized `Summary` protobufs to write.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tensor
  );

  let results = (outs);
}

def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with scalar values.";

  let description = [{
The input `tag` and `value` must be scalars.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Tag for the summary.
value: Value for the summary.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_IntOrFpTensor:$value
  );

  let results = (outs);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}

def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> {
  let summary = "Outputs a `Summary` protocol buffer with a tensor.";

  let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor to serialize.
tag: The summary's tag.
summary_metadata: Serialized SummaryMetadata protocol buffer containing
  plugin-related metadata for this summary.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_Tensor:$tensor,
    TF_StrTensor:$tag,
    TF_StrTensor:$summary_metadata
  );

  let results = (outs);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}

def TF__TPUDeviceOrdinalPlaceholderOp : TF_Op<"_TPUDeviceOrdinalPlaceholder", []> {
  let summary = [{
Placeholder device ordinal that represents device ordinal of a replicated op.
  }];

  let description = [{
This op can be used when certain rewrite passes materialize ops that require a
device ordinal of a replicated op but replication logic has been abstracted away
using the tf_device.replicate op. Subsequent rewrite passes must replace this op
with a constant output that represents the correct device ordinal of the
replicated operations inside a TPU host.
  }];

  let arguments = (ins);

  let results = (outs
    TF_Int64Tensor:$device_ordinal
  );
}

def TF_TPUPartitionedInputOp : TF_Op<"TPUPartitionedInput", [NoSideEffect]> {
  let summary = [{
An op that groups a list of partitioned inputs together.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}

def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> {
  let summary = [{
An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
outputs outside the XLA computation.
  }];

  let arguments = (ins
    TF_Tensor:$inputs,

    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>;
}

// Declares symbol reference attribute `shape_inference_graph` to be optional
// unlike the TensorFlow definition. This is required to support ops that use
// an empty string value for the attribute to signify that it is missing.
def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> {
  let summary = [{
A pseudo-op to represent host-side computation in an XLA program.
  }];

  let arguments = (ins
    Arg<Variadic<TF_Tensor>, [{A list of tensors that will be sent to the host.}]>:$inputs,

    StrArrayAttr:$ancestors,
    TF_ShapeAttrArray:$shapes,
    OptionalAttr<SymbolRefAttr>:$shape_inference_graph,
    StrAttr:$key,
    DefaultValuedAttr<StrAttr, "">:$send_key,
    DefaultValuedAttr<StrAttr, "">:$recv_key,
    DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns,
    DefaultValuedAttr<I64Attr, "0">:$tpu_core
  );

  let results = (outs
    Res<Variadic<TF_Tensor>, [{A list of tensors that will be returned to the device.}]>:$outputs
  );

  TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}

def TF_ConfigureAndInitializeGlobalTPUOp : TF_Op<"ConfigureAndInitializeGlobalTPU", []> {
  let summary = [{
An op that initializes the TPU system in a multi-client setup.
  }];

  let description = [{
Initializes the global TPU system for multi-client execution.

This op does the work of both ConfigureDistributedTpuOp and
InitializeHostForDistributedTpuOp, and outputs the latter's result.
  }];

  let arguments = (ins);

  let results = (outs
    Res<TF_Int32Tensor, [{A vector containing the global TPU id of each TPU on the host.}]>:$output
  );
}

def TF_ShutdownTPUSystemOp : TF_Op<"ShutdownTPUSystem", []> {
  let summary = [{
An op that shuts down the TPU system.
  }];

  let arguments = (ins);
  let results = (outs
    TF_BoolTensor:$success
  );
}

// Internal op for testing value-based side effects for non-resource values.
// TODO(mgester) We should have an extension of the TF dialect only for testing
// so the TF dialect is not polluted with test ops.
def TF__InternalTestNonResourceValueSideEffects_ : TF_Op<"_InternalTestNonResourceValueSideEffects_", []> {
  let summary = "Internal op for testing only";

  let arguments = (ins
    Arg<TF_StrTensor,"", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$key
  );
  let results = (outs);
}

#endif // TF_OPS