/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
// This is the base operation definition file for TensorFlow.
//
// This file includes the definitions of the TensorFlow dialect, the base
// TensorFlow op, and various commonly used TensorFlow traits, types,
// attributes, and builders.
21
22#ifndef TF_OP_BASE
23#define TF_OP_BASE
24
25include "mlir/IR/OpBase.td"
26include "mlir/Interfaces/SideEffectInterfaces.td"
27include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
28
//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//

// The `tf` dialect record. C++ generated for this dialect is placed in the
// `::mlir::TF` namespace.
def TF_Dialect : Dialect {
  let name = "tf";
  let cppNamespace = "::mlir::TF";

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];
}
51
//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//

// All results share one type, and every operand has either that same type or
// the ref type corresponding to it.
def TF_OperandsSameAsResultsTypeOrRef :
    NativeOpTrait<"TF::OperandsSameAsResultsTypeOrRef">;

// Operands and results agree on element type (or on the type itself for
// scalars) once reference types have been resolved, i.e. converted to their
// corresponding TensorFlow or standard types.
def TF_SameOperandsAndResultElementTypeResolveRef :
    NativeOpTrait<"TF::SameOperandsAndResultElementTypeResolveRef">;

// Operands and results agree on the whole type once reference types have been
// resolved, i.e. converted to their corresponding TensorFlow or standard
// types.
def TF_SameOperandsAndResultTypeResolveRef :
    NativeOpTrait<"TF::SameOperandsAndResultTypeResolveRef">;

// The op does not depend on the data layout (data format) of its operands.
// All element-wise operations, for example, are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// The op cannot be duplicated because its implementation might carry certain
// state around.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// The op cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient-wise binary op with implicit broadcasting support, e.g. tf.Sub.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient-wise unary op, e.g. tf.Sqrt.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;
91
// Broadcastability predicate between operand `opId` and result `resId` that
// takes TF's subtype behavior into account. It is the conjunction of
// `TCOpResIsShapedTypePred<opId, resId>` and a call to
// `mlir::tf_type::BroadcastCompatible` on the operand type and the result
// type.
class TF_OpIsBroadcastableToRes<int opId, int resId> :
    And<[
      TCOpResIsShapedTypePred<opId, resId>,
      CPred<"mlir::tf_type::BroadcastCompatible("
                "$_op.getOperand(" # opId # ").getType(), "
                "$_op.getResult(" # resId # ").getType())">
    ]>;
98
99
// Predicate that holds when `AreCastCompatible` accepts the list of types
// produced by `values`.
//
// `values`: C++ expressions, each evaluating to a type. They are joined with
// ", " and wrapped in a braced initializer list passed through
// `llvm::makeArrayRef`.
//
// The callee is spelled with its enclosing `mlir` namespace, matching the
// `mlir::tf_type::BroadcastCompatible` call emitted by
// `TF_OpIsBroadcastableToRes`, so the generated code does not rely on being
// emitted inside namespace `mlir` for name lookup to succeed.
class TF_AllTypesMatchPred<list<string> values> :
    CPred<"mlir::tf_type::AreCastCompatible(llvm::makeArrayRef({" #
      !interleave(values, ", ") # "}))">;
103
// Op trait requiring that the types of all the named values match
// dynamically, as decided by `TF_AllTypesMatchPred`.
//
// `names`: names of the op's values to compare. Each name `n` is turned into
// the expression "$n.getType()" (by substituting "$n" for "$_self" in
// "$_self.getType()") before being handed to the predicate.
class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<
      "all of {" # !interleave(names, ", ") # "} have dynamically equal types ",
      TF_AllTypesMatchPred<
        !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))
      >
    >;
110
// Marks a resource-allocating op (i.e. an op with a `MemAlloc` side effect)
// whose returned resources are all unique.
//
// Using the trait where that invariant does not hold might lead to incorrect
// execution order. Omitting it where it would be valid can only reduce
// performance, because ordering then stays conservative.
// `TF_VarHandleOp` is an example of an op for which the invariant is not true.
def TF_UniqueResourceAllocation : OpTraitList<[
  TF_ResourceHandleAllocatorInterface,
  NativeOpTrait<"TF::UniqueResourceAllocation">
]>;
122
//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

// Holds when the type of the n-th operand is an `UnrankedTensorType`.
class TF_OperandIsUnrankedPred<int n> :
    CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// Holds when the type of the n-th result is an `UnrankedTensorType`.
class TF_ResultIsUnrankedPred<int n> :
    CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Trait satisfied when the n-th operand is unranked or has rank m. The rank
// comparison casts the operand type to `ShapedType` and is listed after the
// unranked check.
class TF_OperandHasRank<int n, int m> :
    PredOpTrait<
      "operand " # n # " is " # m # "-D",
      Or<[
        TF_OperandIsUnrankedPred<n>,
        CPred<"$_op.getOperand(" # n #
              ").getType().cast<ShapedType>().getRank() == " # m>
      ]>
    >;

// Trait satisfied when the n-th result is unranked or has rank m. The rank
// comparison casts the result type to `ShapedType` and is listed after the
// unranked check.
class TF_ResultHasRank<int n, int m> :
    PredOpTrait<
      "result " # n # " is " # m # "-D",
      Or<[
        TF_ResultIsUnrankedPred<n>,
        CPred<"$_op.getResult(" # n #
              ").getType().cast<ShapedType>().getRank() == " # m>
      ]>
    >;
146
//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//

// Base class of TensorFlow side-effect resources. `resourceKind` is appended
// to "::mlir::TF::ResourceEffects::" to form the resource's C++ name.
class TF_ResourceBase<string resourceKind> :
    Resource<"::mlir::TF::ResourceEffects::" # resourceKind>;
154
// Value-based side-effects
//
// Each resource is followed by the memory effects declared on it. Not every
// resource declares all four of read / write / alloc / free.

// Variable
def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_VariableRead  : MemRead<TF_VariableResource>;
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;

// Stack
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_StackRead  : MemRead<TF_StackResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_StackFree  : MemFree<TF_StackResource>;

// TensorArray
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_TensorArrayRead  : MemRead<TF_TensorArrayResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_TensorArrayFree  : MemFree<TF_TensorArrayResource>;

// Summary
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_SummaryFree  : MemFree<TF_SummaryResource>;

// LookupTable
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_LookupTableRead  : MemRead<TF_LookupTableResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;

// DatasetSeedGenerator
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetSeedGeneratorRead  : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetSeedGeneratorFree  : MemFree<TF_DatasetSeedGeneratorResource>;

// DatasetMemoryCache
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetMemoryCacheRead  : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetMemoryCacheFree  : MemFree<TF_DatasetMemoryCacheResource>;

// DatasetIterator
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_DatasetIteratorRead  : MemRead<TF_DatasetIteratorResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;
def TF_DatasetIteratorFree  : MemFree<TF_DatasetIteratorResource>;

// Op-based side-effects
//
// Each of these is a `MemoryEffects` list holding a single write on a
// dedicated resource.

def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
def TF_TPUEmbeddingSideEffect :
    MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;

def TF_GeneratorOpResource : TF_ResourceBase<"GeneratorOp">;
def TF_GeneratorOpSideEffect :
    MemoryEffects<[MemWrite<TF_GeneratorOpResource>]>;
203
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

// Base class for ops of the TensorFlow dialect: an `Op` bound to `TF_Dialect`
// with the given mnemonic and an optional list of traits (empty by default).
class TF_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<TF_Dialect, mnemonic, traits>;
210
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

// Base class for attributes defined by the TensorFlow dialect. The constraint
// accepts attributes that are instances of `mlir::TF::<name>Attr`;
// `description` is spliced into the summary
// "TensorFlow <description> attribute".
class TF_TensorFlowAttr<string name, string description> :
    Attr<
      CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
      "TensorFlow " # description # " attribute"
    >;

// Shape attribute (`mlir::TF::ShapeAttr`). The accessor returns the result of
// `getValue()` on the stored attribute as an optional array of dimensions.
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  // Create a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";

  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";
}

// Array attribute whose elements are all `TF_ShapeAttr`.
def TF_ShapeAttrArray :
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;
229
//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//

// Matches any tensor element type defined in the TensorFlow dialect, i.e. any
// type that is an `mlir::TF::TensorFlowType`.
def TF_TFDialectType :
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Class for one specific TensorFlow dialect type. The constraint matches
// `mlir::TF::<name>Type`, and the type is buildable through
// `getType<mlir::TF::<name>Type>()`. `description` is spliced into the summary
// "TensorFlow <description> type".
class TF_TensorFlowType<string name, string description> :
    Type<
      CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
      "TensorFlow " # description # " type"
    >,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;
243
//===----------------------------------------------------------------------===//
// Reference types
//
// One constraint per TensorFlow reference type; each matches the dialect type
// `mlir::TF::<name>Type` named by its first template argument.

// Float reference types
def TF_Float16Ref  : TF_TensorFlowType<"HalfRef",     "f16ref">;
def TF_Float32Ref  : TF_TensorFlowType<"FloatRef",    "f32ref">;
def TF_Float64Ref  : TF_TensorFlowType<"DoubleRef",   "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex reference types
def TF_Complex64Ref  : TF_TensorFlowType<"Complex64Ref",  "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Signed integer reference types
def TF_Int8Ref  : TF_TensorFlowType<"Int8Ref",  "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

// Unsigned integer reference types
def TF_Uint8Ref  : TF_TensorFlowType<"Uint8Ref",  "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized reference types
def TF_Qint8Ref   : TF_TensorFlowType<"Qint8Ref",   "qint8ref">;
def TF_Qint16Ref  : TF_TensorFlowType<"Qint16Ref",  "qint16ref">;
def TF_Qint32Ref  : TF_TensorFlowType<"Qint32Ref",  "qint32ref">;
def TF_Quint8Ref  : TF_TensorFlowType<"Quint8Ref",  "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Other reference types
def TF_BoolRef     : TF_TensorFlowType<"BoolRef",     "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef      : TF_TensorFlowType<"StringRef",   "stringref">;
def TF_VariantRef  : TF_TensorFlowType<"VariantRef",  "variantref">;
280
//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)
//
// Every element constraint below accepts a builtin integer type as well as
// the matching TensorFlow reference type.

def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

// Signed integers of a fixed width
def TF_Int8  : AnyTypeOf<[I8,  TF_Int8Ref],  "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;

// Either 32-bit or 64-bit signed integers
def TF_I32OrI64 :
    AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref], "32/64-bit signed integer">;

// Unsigned integers of a fixed width
def TF_Uint8  : AnyTypeOf<[UI<8>,  TF_Uint8Ref],  "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Any unsigned integer type
def TF_UInt :
    AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64], "unsigned integer">;

// Any signed integer type
def TF_SInt :
    AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64], "signed integer">;

// Any integer type, signed or unsigned
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor types: tensors whose element type satisfies the named constraint
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor      : TensorOf<[TF_Int]>;
def TF_Int8Tensor     : TensorOf<[TF_Int8]>;
def TF_Int16Tensor    : TensorOf<[TF_Int16]>;
def TF_Int32Tensor    : TensorOf<[TF_Int32]>;
def TF_Int64Tensor    : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor  : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
323
//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)
//
// Each constraint accepts the TensorFlow quantized type itself as well as its
// reference counterpart.

def TF_Qint8 :
    AnyTypeOf<[TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
              "8-bit quantized integer">;
def TF_Qint16 :
    AnyTypeOf<[TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
              "16-bit quantized integer">;
def TF_Qint32 :
    AnyTypeOf<[TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
              "32-bit quantized integer">;
def TF_Quint8 :
    AnyTypeOf<[TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
              "8-bit quantized unsigned integer">;
def TF_Quint16 :
    AnyTypeOf<[TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
              "16-bit quantized unsigned integer">;

// Any quantized type
def TF_Quantized :
    AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16],
              "quantized">;
346
//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)
//
// Each element constraint accepts a builtin float type as well as the matching
// TensorFlow reference type.

def TF_Float16  : AnyTypeOf<[F16,  TF_Float16Ref],  "16-bit float">;
def TF_Float32  : AnyTypeOf<[F32,  TF_Float32Ref],  "32-bit float">;
def TF_Float64  : AnyTypeOf<[F64,  TF_Float64Ref],  "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

// Either 32-bit or 64-bit floats
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

// Any of the floating-point types above
def TF_Float :
    AnyTypeOf<[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
              "floating-point">;

// Tensor types: tensors whose element type satisfies the named constraint
def TF_FloatTensor    : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor  : TensorOf<[TF_Float16]>;
def TF_Float32Tensor  : TensorOf<[TF_Float32]>;
def TF_Float64Tensor  : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
368
//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.

// Complex number with 32-bit float components, or its reference type.
def TF_Complex64 :
    AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref], "64-bit complex">;

// Complex number with 64-bit float components, or its reference type.
def TF_Complex128 :
    AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref], "128-bit complex">;

// Any of the complex types above
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor types: tensors whose element type satisfies the named constraint
def TF_ComplexTensor    : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor  : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
384
//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)
//
// Each element constraint accepts the TensorFlow dialect type as well as its
// reference counterpart; the tensor constraint accepts tensors of those
// elements.

// String
def TF_Str :
    AnyTypeOf<[TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

// Variant
def TF_Variant :
    AnyTypeOf<[TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

// Resource
def TF_Resource :
    AnyTypeOf<[TF_TensorFlowType<"Resource", "res">, TF_ResourceRef],
              "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
399
//===----------------------------------------------------------------------===//
// Multi-category type constraints
//
// Tensor constraints whose element type may come from more than one of the
// categories defined above.

def TF_IntOrF32OrF64Tensor : TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor  : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor       : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor      : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor   : TensorOf<[TF_Float, TF_Complex]>;

// Any integer, floating-point, quantized or complex element type
def TF_Number :
    AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

// Floating-point, signed integer, complex, 8-bit unsigned integer or string
// element types. No summary string is supplied for this constraint.
def TF_NumberNotQuantizedOrStr :
    AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
416
//===----------------------------------------------------------------------===//
// Tensor and tensor element types

// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType).
// The constraint is the disjunction of the float, complex, integer and bool
// predicates with the predicate matching any TensorFlow dialect type.
def TF_ElementType :
    Type<
      Or<[
        TF_Float.predicate,
        TF_Complex.predicate,
        TF_Int.predicate,
        TF_Bool.predicate,
        TF_TFDialectType.predicate
      ]>,
      "tf.dtype"
    >;

// Any TensorFlow tensor type: a tensor whose element type is a
// `TF_ElementType`.
def TF_Tensor : TensorOf<[TF_ElementType]>;
431
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// TensorFlow devices metadata

// TensorFlow GPU device metadata: a struct attribute in the TF dialect whose
// two 32-bit integer fields hold the GPU device compute capability as
// major:minor.
def TF_GpuDeviceMetadata :
    StructAttr<"GpuDeviceMetadata", TF_Dialect, [
      StructFieldAttr<"cc_major", I32Attr>,
      StructFieldAttr<"cc_minor", I32Attr>
    ]>;
445
//===----------------------------------------------------------------------===//
// String attribute constraints

// A string attribute whose value is one of the values in `cases`.
//
// The predicate is an "||" chain of equality tests against each case, and the
// summary lists the cases separated by ", or ". Both folds are seeded with
// `!head(cases)`, so `cases` is expected to contain at least one entry.
class TF_AnyStrAttrOf<list<string> cases> :
    StringBasedAttr<
      CPred<
        !foldl(
          "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
          !foreach(value, !tail(cases),
                   "$_self.cast<StringAttr>().getValue() == \"" # value # "\""),
          acc, test, acc # " || " # test)
      >,
      "string attribute whose value is " #
        !foldl(/*init*/!head(cases), /*list*/!tail(cases),
               acc, value, acc # ", or " # value)
    >;
459
// TODO: Use EnumAttr to define the common attribute cases

// String attribute naming a convnet data format. The predicate accepts the
// values "NHWC" and "NCHW".
def TF_ConvnetDataFormatAttr :
    StringBasedAttr<
      CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
            "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
      "'NHWC' or 'NCHW' convnet data format"
    >;
466
//===----------------------------------------------------------------------===//
// Type attributes

// Derived attribute giving the size of the `idx`-th ODS-declared variadic
// operand, computed as the distance between the begin and end of
// `getODSOperands(idx)`. It materializes as an I64 integer attribute.
class TF_DerivedOperandSizeAttr<int idx> :
    DerivedAttr<
      "size_t",
      "auto range = getODSOperands(" # idx # ");\n"
      "return std::distance(range.begin(), range.end());",
      [{ $_builder.getI64IntegerAttr($_self) }]
    >;

// Derived attribute giving the element type of the `idx`-th ODS-declared
// operand. For a variadic operand only the first tensor is inspected, so the
// result is meaningful only when the variadic operand holds at least one
// tensor and all of its tensors share one element type.
class TF_DerivedOperandTypeAttr<int idx> :
    DerivedTypeAttr<
      "return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());"
    >;
485
// Derived attribute giving the element types of all tensors in the value pack
// bound to the `idx`-th ODS-declared variadic operand. Because it yields a
// list, it suits variadic operands whose tensors can have different element
// types. It materializes as an array of type attributes.
class TF_DerivedOperandTypeListAttr<int idx> :
    DerivedAttr<
      "mlir::OperandElementTypeRange",
      "auto values = getODSOperands(" # idx # ");\n"
      "return {mlir::OperandElementTypeIterator(values.begin()), "
              "mlir::OperandElementTypeIterator(values.end())};",
      [{
        ArrayAttr::get($_ctx,
          [&]() {
            llvm::SmallVector<Attribute, 4> attrs;
            for (auto elementType : $_self)
              attrs.push_back(TypeAttr::get(elementType));
            return attrs;
          }())
      }]
    >;
505
// Derived attribute giving the shapes of all tensors in the value pack bound
// to the `idx`-th ODS-declared variadic operand. Because it yields a list, it
// suits variadic operands whose tensors can have different shapes. It
// materializes as an array of `mlir::TF::ShapeAttr`.
class TF_DerivedOperandShapeListAttr<int idx> :
    DerivedAttr<
      "::mlir::TF::OperandShapeRange",
      "auto values = getODSOperands(" # idx # ");\n"
      "return {mlir::TF::OperandShapeIterator(values.begin()), "
              "mlir::TF::OperandShapeIterator(values.end())};",
      [{
        ArrayAttr::get($_ctx,
          [&]() {
            llvm::SmallVector<Attribute, 4> attrs;
            for (auto shape : $_self)
              attrs.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
            return attrs;
          }())
      }]
    >;
525
// Derived attribute giving the size of the `idx`-th ODS-declared variadic
// result, computed as the distance between the begin and end of
// `getODSResults(idx)`. It materializes as an I64 integer attribute.
class TF_DerivedResultSizeAttr<int idx> :
    DerivedAttr<
      "size_t",
      "auto range = getODSResults(" # idx # ");\n"
      "return std::distance(range.begin(), range.end());",
      [{ $_builder.getI64IntegerAttr($_self) }]
    >;

// Derived attribute giving the element type of the `idx`-th ODS-declared
// result. For a variadic result only the first tensor is inspected, so the
// result is meaningful only when the variadic result holds at least one tensor
// and all of its tensors share one element type.
class TF_DerivedResultTypeAttr<int idx> :
    DerivedTypeAttr<
      "return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());"
    >;
541
// Derived attribute giving the element types of all tensors in the value pack
// bound to the `idx`-th ODS-declared variadic result. Because it yields a
// list, it suits variadic results whose tensors can have different element
// types. It materializes as an array of type attributes.
class TF_DerivedResultTypeListAttr<int idx> :
    DerivedAttr<
      "mlir::ResultElementTypeRange",
      "auto values = getODSResults(" # idx # ");\n"
      "return {mlir::ResultElementTypeIterator(values.begin()), "
              "mlir::ResultElementTypeIterator(values.end())};",
      [{
        ArrayAttr::get($_ctx,
          [&]() {
            llvm::SmallVector<Attribute, 4> attrs;
            for (auto elementType : $_self)
              attrs.push_back(TypeAttr::get(elementType));
            return attrs;
          }())
      }]
    >;
561
// Derived attribute giving the shapes of all tensors in the value pack bound
// to the `idx`-th ODS-declared variadic result. Because it yields a list, it
// suits variadic results whose tensors can have different shapes. It
// materializes as an array of `mlir::TF::ShapeAttr`.
class TF_DerivedResultShapeListAttr<int idx> :
    DerivedAttr<
      "mlir::TF::ResultShapeRange",
      "auto values = getODSResults(" # idx # ");\n"
      "return {mlir::TF::ResultShapeIterator(values.begin()), "
              "mlir::TF::ResultShapeIterator(values.end())};",
      [{
        ArrayAttr::get($_ctx,
          [&]() {
            llvm::SmallVector<Attribute, 4> attrs;
            for (auto shape : $_self)
              attrs.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
            return attrs;
          }())
      }]
    >;
581
// Derived attribute for the shape of the first result. The accessor returns
// the first result type cast to `ShapedType`; it materializes as an
// `mlir::TF::ShapeAttr` built from that type.
def TF_DerivedResultShapeAttr :
    DerivedAttr<
      "ShapedType",
      "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
      [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]
    >;

// Type attribute constrained to `IntegerType`; the accessor's return type is
// the generic `Type`.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}
590
//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//

// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
//
// The builder takes the two operands `x` and `y` and derives the result type
// with `OpTrait::util::getBroadcastedType`. When no broadcasted type exists,
// an error is emitted at the op location and `build` is still invoked with
// the null type.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  Type broadcastType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!broadcastType) {
    mlir::emitError($_state.location, "non-broadcastable operands");
  }
  return build($_builder, $_state, broadcastType, x, y);
}]>];
}
608
// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool (i1) element type.
//
// If either operand is unranked, the result is an unranked i1 tensor.
// Otherwise the result is a ranked i1 tensor with the shape computed by
// `OpTrait::util::getBroadcastedShape`. When that computation fails, an error
// is emitted at the op location and the op is still built from whatever shape
// vector the call left behind.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  Type lhsType = x.getType();
  Type rhsType = y.getType();
  Type resultType;
  if (!lhsType.isa<UnrankedTensorType>() &&
      !rhsType.isa<UnrankedTensorType>()) {
    SmallVector<int64_t, 4> broadcastShape;
    const bool isBroadcastable = OpTrait::util::getBroadcastedShape(
        lhsType.cast<ShapedType>().getShape(),
        rhsType.cast<ShapedType>().getShape(), broadcastShape);
    if (!isBroadcastable) {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
    }

    resultType = RankedTensorType::get(broadcastShape, $_builder.getI1Type());
  } else {
    resultType = UnrankedTensorType::get($_builder.getI1Type());
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}
633
634#endif // TF_OP_BASE
635