• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// This is the base operation definition file for TensorFlow.
17//
18// This file includes the definition for the TensorFlow dialect, base TensorFlow
19// op, and various commonly used TensorFlow traits, types, attributes, and
20// builders.
21
22#ifndef TF_OP_BASE
23#define TF_OP_BASE
24
25include "mlir/IR/OpBase.td"
26include "mlir/Interfaces/SideEffectInterfaces.td"
27include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
28
29//===----------------------------------------------------------------------===//
30// TensorFlow dialect definitions
31//===----------------------------------------------------------------------===//
32
// The `tf` dialect. Its generated C++ classes are placed in ::mlir::TF.
def TF_Dialect : Dialect {
  let name = "tf";
  let cppNamespace = "::mlir::TF";

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];
}
51
52//===----------------------------------------------------------------------===//
53// TensorFlow traits
54//===----------------------------------------------------------------------===//
55
// All results share one type, and every operand has either that same type or
// the ref type that corresponds to it.
def TF_OperandsSameAsResultsTypeOrRef
    : NativeOpTrait<"TF::OperandsSameAsResultsTypeOrRef">;

// Operands and results share an element type (or the type itself, when
// scalar) once ref types have been resolved to their TensorFlow or standard
// counterparts.
def TF_SameOperandsAndResultElementTypeResolveRef
    : NativeOpTrait<"TF::SameOperandsAndResultElementTypeResolveRef">;

// Operands and results share a type once ref types have been resolved to
// their TensorFlow or standard counterparts.
def TF_SameOperandsAndResultTypeResolveRef
    : NativeOpTrait<"TF::SameOperandsAndResultTypeResolveRef">;

// The op does not depend on the data layout (data format) of its operands;
// every element-wise op is an example of a layout agnostic op.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// The op must not be duplicated because its implementation may carry state.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// The op must not be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient-wise binary op with implicit broadcasting, e.g. tf.Sub.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient-wise unary op, e.g. tf.Sqrt.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;

// Predicate (not a trait by itself): operand `opId` and result `resId` pass
// TCOpResIsShapedTypePred, and mlir::TF::BroadcastCompatible accepts their
// types. Routing the check through the TF helper lets TF's subtype behavior
// be taken into account, unlike the generic broadcastable check.
class TF_OpIsBroadcastableToRes<int opId, int resId>
    : And<[
        TCOpResIsShapedTypePred<opId, resId>,
        CPred<"mlir::TF::BroadcastCompatible("
                  "$_op.getOperand(" # opId # ").getType(), "
                  "$_op.getResult(" # resId # ").getType())">
      ]>;
98
99
// C++ predicate that holds when TF::AreCastCompatible accepts the types
// yielded by the given C++ expressions, passed together as one ArrayRef.
class TF_AllTypesMatchPred<list<string> values>
    : CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
            !interleave(values, ", ") #
            "}))">;

// Op trait requiring the named operands/results to have cast-compatible
// ("dynamically equal") types. Every name `foo` is turned into the C++
// expression `$foo.getType()` before being handed to TF_AllTypesMatchPred.
class TF_AllTypesMatch<list<string> names>
    : PredOpTrait<
          "all of {" # !interleave(names, ", ") #
              "} have dynamically equal types ",
          TF_AllTypesMatchPred<
              !foreach(argName, names,
                       !subst("$_self", "$" # argName, "$_self.getType()"))>>;
110
111//===----------------------------------------------------------------------===//
112// Rank/Shape helpers.
113//===----------------------------------------------------------------------===//
114
// C++ predicate: the type of the `n`-th operand is an UnrankedTensorType.
class TF_OperandIsUnrankedPred<int n>
    : CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// C++ predicate: the type of the `n`-th result is an UnrankedTensorType.
class TF_ResultIsUnrankedPred<int n>
    : CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Op trait: operand `n` is unranked, or its rank equals `m`. The rank check
// casts the type to ShapedType, so the operand is assumed to be shaped.
class TF_OperandHasRank<int n, int m>
    : PredOpTrait<"operand " # n # " is " # m # "-D",
                  Or<[
                    TF_OperandIsUnrankedPred<n>,
                    CPred<"$_op.getOperand(" # n #
                          ").getType().cast<ShapedType>().getRank() == " # m>
                  ]>>;

// Op trait: result `n` is unranked, or its rank equals `m`. The rank check
// casts the type to ShapedType, so the result is assumed to be shaped.
class TF_ResultHasRank<int n, int m>
    : PredOpTrait<"result " # n # " is " # m # "-D",
                  Or<[
                    TF_ResultIsUnrankedPred<n>,
                    CPred<"$_op.getResult(" # n #
                          ").getType().cast<ShapedType>().getRank() == " # m>
                  ]>>;
134
135//===----------------------------------------------------------------------===//
136// TensorFlow op side effects
137//===----------------------------------------------------------------------===//
138
// Base class for TensorFlow side-effect resources. `resourceKind` names the
// C++ resource class ::mlir::TF::ResourceEffects::<resourceKind>.
class TF_ResourceBase<string resourceKind>
    : Resource<"::mlir::TF::ResourceEffects::" # resourceKind>;

// Resources that TensorFlow ops may act on.
def TF_VariableResource             : TF_ResourceBase<"Variable">;
def TF_StackResource                : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource          : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource              : TF_ResourceBase<"Summary">;
def TF_LookupTableResource          : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource   : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource      : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource         : TF_ResourceBase<"TPUEmbedding">;

// Read effects (none is declared for Summary or TPUEmbedding).
def TF_VariableRead             : MemRead<TF_VariableResource>;
def TF_StackRead                : MemRead<TF_StackResource>;
def TF_TensorArrayRead          : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead          : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead   : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead      : MemRead<TF_DatasetIteratorResource>;

// Write effects.
def TF_VariableWrite             : MemWrite<TF_VariableResource>;
def TF_StackWrite                : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite          : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite              : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite          : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite   : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite      : MemWrite<TF_DatasetIteratorResource>;

// Allocation effects.
def TF_VariableAlloc             : MemAlloc<TF_VariableResource>;
def TF_StackAlloc                : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc          : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc              : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc          : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc   : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc      : MemAlloc<TF_DatasetIteratorResource>;

// Free effects (none is declared for Variable, LookupTable or TPUEmbedding).
def TF_StackFree                : MemFree<TF_StackResource>;
def TF_TensorArrayFree          : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree              : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree   : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree      : MemFree<TF_DatasetIteratorResource>;

// Memory-effects trait declaring a single write to the TPUEmbedding resource.
def TF_TPUEmbeddingSideEffect
    : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
187
188//===----------------------------------------------------------------------===//
189// TensorFlow op definitions
190//===----------------------------------------------------------------------===//
191
// Common base for every op in the TensorFlow dialect: forwards the mnemonic
// and trait list to Op, fixing the dialect to TF_Dialect.
class TF_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<TF_Dialect, mnemonic, traits>;
194
195//===----------------------------------------------------------------------===//
196// TensorFlow attribute definitions
197//===----------------------------------------------------------------------===//
198
// Attribute constraint matching mlir::TF::<name>Attr, summarized as
// "TensorFlow <description> attribute".
class TF_TensorFlowAttr<string name, string description>
    : Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
           "TensorFlow " # description # " attribute">;

// TensorFlow shape attribute, exposed in C++ as an optional list of
// dimension sizes.
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";

  // Constant builders produce a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}

// Array attribute whose elements all satisfy TF_ShapeAttr.
def TF_ShapeAttrArray
    : TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;
213
214//===----------------------------------------------------------------------===//
215// TensorFlow type definitions
216//===----------------------------------------------------------------------===//
217
// Matches any type that is a mlir::TF::TensorFlowType, i.e. any element type
// defined by the TensorFlow dialect.
def TF_TFDialectType
    : Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Buildable constraint for the single TensorFlow dialect type
// mlir::TF::<name>Type, summarized as "TensorFlow <description> type".
class TF_TensorFlowType<string name, string description>
    : Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
           "TensorFlow " # description # " type">,
      BuildableType<"getType<mlir::TF::" # name # "Type>()">;
227
228//===----------------------------------------------------------------------===//
229// Reference types
230
// Floating-point ref types.
def TF_Float16Ref  : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref  : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref  : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex ref types.
def TF_Complex64Ref  : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Integer ref types.
def TF_Int8Ref  : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

// Unsigned integer ref types.
def TF_Uint8Ref  : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized ref types.
def TF_Qint8Ref   : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref  : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref  : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref  : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Remaining ref types: bool, resource, string and variant.
def TF_BoolRef     : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef      : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef  : TF_TensorFlowType<"VariantRef", "variantref">;
264
265//===----------------------------------------------------------------------===//
266// Integer types (including corresponding reference types)
267
// Boolean: i1 or its ref type.
def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

// Fixed-width integers, each paired with its ref type.
def TF_Int8  : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;

// 32- or 64-bit integer, plain or ref.
def TF_I32OrI64
    : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
                "32/64-bit signed integer">;

// Fixed-width unsigned integers, each paired with its ref type.
def TF_Uint8  : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Umbrella constraints: any unsigned, any signed, and any integer at all.
def TF_UInt
    : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
                "unsigned integer">;
def TF_SInt
    : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
                "signed integer">;
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor constraints over the element types above.
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor      : TensorOf<[TF_Int]>;
def TF_Int8Tensor     : TensorOf<[TF_Int8]>;
def TF_Int16Tensor    : TensorOf<[TF_Int16]>;
def TF_Int32Tensor    : TensorOf<[TF_Int32]>;
def TF_Int64Tensor    : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor  : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
301def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;
302
303def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
304def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
305def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
306def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
307
308//===----------------------------------------------------------------------===//
309// Quantized types (including corresponding reference types)
310
// Quantized integers: the TensorFlow dialect type or its ref type.
def TF_Qint8
    : AnyTypeOf<[TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
                "8-bit quantized integer">;
def TF_Qint16
    : AnyTypeOf<[TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
                "16-bit quantized integer">;
def TF_Qint32
    : AnyTypeOf<[TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
                "32-bit quantized integer">;
def TF_Quint8
    : AnyTypeOf<[TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
                "8-bit quantized unsigned integer">;
def TF_Quint16
    : AnyTypeOf<[TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
                "16-bit quantized unsigned integer">;

// Any of the quantized types above.
def TF_Quantized
    : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16],
                "quantized">;
330
331//===----------------------------------------------------------------------===//
332// Floating-point types (including corresponding reference types)
333
// Floating-point types, each paired with its ref type.
def TF_Float16  : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32  : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64  : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

// f32 or f64 (plain or ref).
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

// Any floating-point type above.
def TF_Float
    : AnyTypeOf<[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
                "floating-point">;

// Tensor constraints over the floating-point element types.
def TF_FloatTensor    : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor  : TensorOf<[TF_Float16]>;
def TF_Float32Tensor  : TensorOf<[TF_Float32]>;
def TF_Float64Tensor  : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
352
353//===----------------------------------------------------------------------===//
354// Complex types (including corresponding reference types)
355
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64
    : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref], "64-bit complex">;
def TF_Complex128
    : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref], "128-bit complex">;

// Either complex width (plain or ref).
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor constraints over the complex element types.
def TF_ComplexTensor    : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor  : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
368
369//===----------------------------------------------------------------------===//
370// String/variant/resource types (including corresponding reference types)
371
// String, variant and resource element types. Each one is its TensorFlow
// dialect type or the matching ref type, and each has a tensor constraint.
def TF_Str
    : AnyTypeOf<[TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant
    : AnyTypeOf<[TF_TensorFlowType<"Variant", "var">, TF_VariantRef],
                "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource
    : AnyTypeOf<[TF_TensorFlowType<"Resource", "res">, TF_ResourceRef],
                "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
383
384//===----------------------------------------------------------------------===//
385// Multi-category type constraints
386
// Tensor constraints whose elements may come from several categories.
def TF_IntOrF32OrF64Tensor : TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor  : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor       : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor      : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor   : TensorOf<[TF_Float, TF_Complex]>;

// Any numeric element: integer, floating-point, quantized or complex.
def TF_Number
    : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

// Floating-point, signed integer, complex, uint8 or string elements. No
// quantized types and no unsigned integers other than uint8 are accepted.
def TF_NumberNotQuantizedOrStr
    : AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
400
401//===----------------------------------------------------------------------===//
402// Tensor and tensor element types
403
// Any element type a TensorFlow tensor may hold: floating-point, complex,
// integer, bool, or any TensorFlow dialect type. The dialect types presumably
// include the quantized, string, variant, resource and ref types; this
// depends on them deriving from mlir::TF::TensorFlowType, so verify if it
// matters. See https://www.tensorflow.org/api_docs/python/tf/dtypes/DType
def TF_ElementType : Type<
    Or<[
      TF_Float.predicate,
      TF_Complex.predicate,
      TF_Int.predicate,
      TF_Bool.predicate,
      TF_TFDialectType.predicate
    ]>,
    "tf.dtype">;

// A tensor of any TF_ElementType.
def TF_Tensor : TensorOf<[TF_ElementType]>;
415
416//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions (continued)
418//===----------------------------------------------------------------------===//
419
420//===----------------------------------------------------------------------===//
// TensorFlow device metadata
422
// TensorFlow GPU device metadata: a GpuDeviceMetadata struct attribute in the
// TF dialect recording the device compute capability as major:minor.
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
    StructFieldAttr<"cc_major", I32Attr>,  // Compute capability, major part.
    StructFieldAttr<"cc_minor", I32Attr>   // Compute capability, minor part.
]>;
429
430//===----------------------------------------------------------------------===//
431// String attribute constraints
432
// A string attribute whose value is one of the strings listed in `cases`.
// The predicate chains one `==` comparison per case with `||`. The summary
// reads "string attribute whose value is A, or B, or C". `cases` must be
// non-empty, because its head seeds both folds.
class TF_AnyStrAttrOf<list<string> cases>
    : StringBasedAttr<
          CPred<!foldl(
              "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
              !foreach(choice, !tail(cases),
                       "$_self.cast<StringAttr>().getValue() == \"" #
                           choice # "\""),
              acc, clause, acc # " || " # clause)>,
          "string attribute whose value is " #
              !foldl(/*init*/!head(cases), /*list*/!tail(cases),
                     acc, choice, acc # ", or " # choice)>;
443
// TODO: Use EnumAttr to define the common attribute cases

// Convnet data format: a string attribute equal to either "NHWC" or "NCHW".
def TF_ConvnetDataFormatAttr
    : StringBasedAttr<
          CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || "
                "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
          "'NHWC' or 'NCHW' convnet data format">;
450
451//===----------------------------------------------------------------------===//
452// Type attributes
453
// Derived attribute counting the values bound to the `idx`-th ODS-declared
// (variadic) operand. The C++ accessor yields a size_t; the attribute form
// is a 64-bit integer attribute.
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
  /*ret=*/"size_t",
  /*body=*/"auto range = getODSOperands(" # idx # ");\n"
           "return std::distance(range.begin(), range.end());",
  /*convert=*/[{ $_builder.getI64IntegerAttr($_self) }]>;
461
// Derived type attribute: the element type of the `idx`-th ODS-declared
// operand. For a variadic operand only the first value is inspected, and its
// begin() iterator is dereferenced unconditionally. The result is therefore
// meaningful only when the pack is non-empty and all its values share one
// element type.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
  /*body=*/"return mlir::getElementTypeOrSelf(*getODSOperands(" # idx #
           ").begin());">;
469
// Derived attribute listing the element type of every value in the pack
// bound to the `idx`-th ODS-declared variadic operand. Use it for variadic
// operands whose values may have different element types. The accessor
// returns an mlir::OperandElementTypeRange; the attribute form is an
// ArrayAttr of TypeAttrs.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
  /*ret=*/"mlir::OperandElementTypeRange",
  /*body=*/"auto values = getODSOperands(" # idx # ");\n"
           "return {mlir::OperandElementTypeIterator(values.begin()), "
                   "mlir::OperandElementTypeIterator(values.end())};",
  /*convert=*/[{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> typeAttrs;
      for (auto elementType : $_self)
        typeAttrs.push_back(TypeAttr::get(elementType));
      return typeAttrs;
    }())
  }]
>;
489
// Derived attribute listing the shape of every value in the pack bound to
// the `idx`-th ODS-declared variadic operand. Use it for variadic operands
// whose values may have different shapes. The accessor returns a
// ::mlir::TF::OperandShapeRange; the attribute form is an ArrayAttr of
// mlir::TF::ShapeAttr.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
  /*ret=*/"::mlir::TF::OperandShapeRange",
  /*body=*/"auto values = getODSOperands(" # idx # ");\n"
           "return {mlir::TF::OperandShapeIterator(values.begin()), "
                   "mlir::TF::OperandShapeIterator(values.end())};",
  /*convert=*/[{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> shapeAttrs;
      for (auto dims : $_self)
        shapeAttrs.push_back(mlir::TF::ShapeAttr::get($_ctx, dims));
      return shapeAttrs;
    }())
  }]
>;
509
// Derived attribute counting the values bound to the `idx`-th ODS-declared
// (variadic) result. The C++ accessor yields a size_t; the attribute form
// is a 64-bit integer attribute.
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
  /*ret=*/"size_t",
  /*body=*/"auto range = getODSResults(" # idx # ");\n"
           "return std::distance(range.begin(), range.end());",
  /*convert=*/[{ $_builder.getI64IntegerAttr($_self) }]>;
517
// Derived type attribute: the element type of the `idx`-th ODS-declared
// result. For a variadic result only the first value is inspected, and its
// begin() iterator is dereferenced unconditionally. The result is therefore
// meaningful only when the pack is non-empty and all its values share one
// element type.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
  /*body=*/"return mlir::getElementTypeOrSelf(*getODSResults(" # idx #
           ").begin());">;
525
// Derived attribute listing the element type of every value in the pack
// bound to the `idx`-th ODS-declared variadic result. Use it for variadic
// results whose values may have different element types. The accessor
// returns an mlir::ResultElementTypeRange; the attribute form is an
// ArrayAttr of TypeAttrs.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
  /*ret=*/"mlir::ResultElementTypeRange",
  /*body=*/"auto values = getODSResults(" # idx # ");\n"
           "return {mlir::ResultElementTypeIterator(values.begin()), "
                   "mlir::ResultElementTypeIterator(values.end())};",
  /*convert=*/[{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> typeAttrs;
      for (auto elementType : $_self)
        typeAttrs.push_back(TypeAttr::get(elementType));
      return typeAttrs;
    }())
  }]
>;
545
// Derived attribute listing the shape of every value in the pack bound to
// the `idx`-th ODS-declared variadic result. Use it for variadic results
// whose values may have different shapes. The accessor returns a
// mlir::TF::ResultShapeRange; the attribute form is an ArrayAttr of
// mlir::TF::ShapeAttr.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
  /*ret=*/"mlir::TF::ResultShapeRange",
  /*body=*/"auto values = getODSResults(" # idx # ");\n"
           "return {mlir::TF::ResultShapeIterator(values.begin()), "
                   "mlir::TF::ResultShapeIterator(values.end())};",
  /*convert=*/[{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> shapeAttrs;
      for (auto dims : $_self)
        shapeAttrs.push_back(mlir::TF::ShapeAttr::get($_ctx, dims));
      return shapeAttrs;
    }())
  }]
>;
565
// Derived attribute: the type of the op's first result, cast to ShapedType.
// The op is assumed to have at least one result whose type is shaped. The
// attribute form is a mlir::TF::ShapeAttr built from that type.
def TF_DerivedResultShapeAttr : DerivedAttr<
  /*ret=*/"ShapedType",
  /*body=*/"return (*getOperation()->result_type_begin()).cast<ShapedType>();",
  /*convert=*/[{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
570
// Type attribute constrained (via TypeAttrBase) to hold an IntegerType. Its
// C++ accessor returns the value as a generic mlir Type.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}
574
575//===----------------------------------------------------------------------===//
576// TensorFlow common builders
577//===----------------------------------------------------------------------===//
578
// Mixin class defining a `build(x, y)` builder for binary ops with implicit
// broadcasting. The result type comes from
// OpTrait::util::getBroadcastedType, so it has the same element type as both
// operands.
//
// If the operand types cannot be broadcast together, getBroadcastedType
// returns a null Type. In that case an error is emitted at the op location
// and the op is still built, with an unranked tensor of `x`'s element type as
// its result type. Passing the null type on to build() would instead create a
// result Value with a null type, which crashes any later code that inspects it.
class WithBroadcastableBinOpBuilder {
  list<OpBuilderDAG> builders = [
    OpBuilderDAG<(ins "Value":$x, "Value":$y),
    [{
  Type resultType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!resultType) {
    mlir::emitError($_state.location, "non-broadcastable operands");
    // Never build with a null result type: an unranked tensor keeps the IR
    // well-formed while the diagnostic above reports the failure.
    resultType =
        UnrankedTensorType::get(mlir::getElementTypeOrSelf(x.getType()));
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}
592
// Mixin class defining a `build(x, y)` builder for comparison ops with
// implicit broadcasting. The result always has i1 (bool) elements:
//   * if either operand is an unranked tensor, the result is an unranked i1
//     tensor;
//   * otherwise both operand types are cast to ShapedType, and the result is
//     a ranked i1 tensor with their broadcast shape.
// If the shapes cannot be broadcast, an error is emitted at the op location
// and the result falls back to an unranked i1 tensor. The builder does not
// use whatever getBroadcastedShape left in its output vector, because that is
// not a meaningful broadcast shape.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilderDAG> builders = [
    OpBuilderDAG<(ins "Value":$x, "Value":$y),
    [{
  Type resultType;
  if (x.getType().isa<UnrankedTensorType>() ||
      y.getType().isa<UnrankedTensorType>()) {
    resultType = UnrankedTensorType::get($_builder.getI1Type());
  } else {
    SmallVector<int64_t, 4> resultShape;
    if (OpTrait::util::getBroadcastedShape(
            x.getType().cast<ShapedType>().getShape(),
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
      resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
    } else {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
      // resultShape does not describe a valid broadcast here, so do not
      // fabricate a static (e.g. rank-0) shape from it.
      resultType = UnrankedTensorType::get($_builder.getI1Type());
    }
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}
617
618#endif // TF_OP_BASE
619