/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow traits, types, attributes, and
// builders.

#ifndef TF_OP_BASE
#define TF_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"

//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//

// The `tf` dialect. The generated C++ dialect class is placed in the
// `::mlir::TF` namespace (see `cppNamespace`).
def TF_Dialect : Dialect {
  let name = "tf";

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];

  let cppNamespace = "::mlir::TF";
}

//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//

// Each NativeOpTrait below only names a C++ trait class (`TF::...`); the
// checking/behavior is implemented in C++, so the comments here describe the
// intended contract rather than the implementation.

// Specify this trait if the op requires all outputs to have the same type and
// the inputs either have the same type as the result or a ref type
// corresponding to the result type.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
  "TF::OperandsSameAsResultsTypeOrRef">;

// Op has the same operand and result element types (or type itself, if scalar)
// after resolving reference types (i.e., after converting reference types to
// their corresponding TensorFlow or standard types).
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultElementTypeResolveRef">;

// Op has the same operand and result types after resolving reference types
// (i.e., after converting reference types to their corresponding TensorFlow or
// standard types).
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultTypeResolveRef">;

// Layout-agnostic operations do not depend on the data layout (data format) of
// their operands; for example, all element-wise operations are layout
// agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// Trait to indicate operations that cannot be duplicated as they might carry
// certain state around within their implementations.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// Trait to indicate an operation cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient-wise (element-wise) binary operation with implicit broadcasting
// support, for example the tf.Sub operation.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient-wise (element-wise) unary operation, for example the tf.Sqrt
// operation.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;

// Variant of the broadcastable check that considers TF's subtype behavior.
//
// Note this is a predicate (an `And` of two predicates), not an op trait; wrap
// it in a PredOpTrait to attach it to an op. It holds when operand `opId` and
// result `resId` pass `TCOpResIsShapedTypePred` (from OpBase.td) and
// `mlir::tf_type::BroadcastCompatible` accepts their two types. Both indices
// are flat indices passed to `getOperand` / `getResult`.
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<opId, resId>,
    CPred<"mlir::tf_type::BroadcastCompatible("
              "$_op.getOperand(" # opId # ").getType(), "
              "$_op.getResult(" # resId # ").getType())">]>;


// Predicate that passes the given C++ expressions, joined with ", " into a
// braced list, to `tf_type::AreCastCompatible`. The expressions are expected to
// be `Type`-valued (TF_AllTypesMatch below produces `$name.getType()`).
class TF_AllTypesMatchPred<list<string> values> :
    CPred<"tf_type::AreCastCompatible(llvm::makeArrayRef({" #
      !interleave(values, ", ") # "}))">;

// Op trait requiring that the types of the named operands/results are
// cast-compatible per TF_AllTypesMatchPred. For each name `n`, `!subst`
// replaces `$_self` in the template "$_self.getType()" with "$n", yielding the
// C++ expression `$n.getType()`.
// NOTE(review): the description string ends with a trailing space, which
// appears verbatim at the end of the verifier diagnostic.
class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<
        "all of {" # !interleave(names, ", ") #
          "} have dynamically equal types ",
        TF_AllTypesMatchPred<
            !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

// This trait indicates that all returned resources are unique for a
// resource-allocating op (i.e. op with `MemAlloc` side-effect).
//
// Note that if the trait is used where this invariant is not true, then this
// might lead to incorrect execution order, while if not used where it should
// be, it can only lead to reduced performance due to conservative ordering.
// Example op where the invariant is not true: `TF_VarHandleOp`.
//
// The list bundles `TF_ResourceHandleAllocatorInterface` (declared outside this
// file, presumably in tf_op_interfaces.td) with the native
// `TF::UniqueResourceAllocation` trait, so ops only need to list this one def.
def TF_UniqueResourceAllocation: OpTraitList<[
  TF_ResourceHandleAllocatorInterface,
  NativeOpTrait<"TF::UniqueResourceAllocation">
]>;

//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

// Predicate: the n-th operand (flat index passed to `getOperand`) has type
// UnrankedTensorType.
class TF_OperandIsUnrankedPred<int n> :
  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// Predicate: the n-th result (flat index passed to `getResult`) has type
// UnrankedTensorType.
class TF_ResultIsUnrankedPred<int n> :
  CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Op trait: the n-th operand either has unknown rank or has rank m.
// NOTE(review): the ranked branch uses `cast<ShapedType>`, so it is only safe
// when the operand's type constraint already guarantees a shaped type.
class TF_OperandHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[TF_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
        ").getType().cast<ShapedType>().getRank() == " # m>]>>;

// Op trait: the n-th result either has unknown rank or has rank m.
// NOTE(review): as above, `cast<ShapedType>` assumes the result is shaped.
class TF_ResultHasRank<int n, int m> :
  PredOpTrait<"result " # n # " is " # m # "-D",
    Or<[TF_ResultIsUnrankedPred<n>,
      CPred<"$_op.getResult(" # n #
        ").getType().cast<ShapedType>().getRank() == " # m>]>>;

//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//

// Base for TF side-effect resources. `resourceKind` names a C++ resource class
// under `::mlir::TF::ResourceEffects::` (e.g. "Variable" ->
// `::mlir::TF::ResourceEffects::Variable`).
class TF_ResourceBase<string resourceKind> :
  Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
}

// Value-based side-effects
def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
def TF_GeneratorOpResource : TF_ResourceBase<"GeneratorOp">;

// Read effects (no read effect is defined for the Summary resource).
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;

// Write effects.
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;

// Allocation effects.
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;

// Free effects (none are defined here for Variable or LookupTable).
def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;

// Op-based side-effects: these wrap a write effect in `MemoryEffects<[...]>`
// so they can be listed directly among an op's traits.
def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
def TF_GeneratorOpSideEffect : MemoryEffects<[MemWrite<TF_GeneratorOpResource>]>;

//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

// Base class for ops in the `tf` dialect (TF_Dialect).
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TF_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

// Attribute constraint matching the C++ attribute class `mlir::TF::<name>Attr`;
// `description` is spliced into the summary "TensorFlow <description> attribute".
class TF_TensorFlowAttr <string name, string description> :
    Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
         "TensorFlow " # description # " attribute">;

// `mlir::TF::ShapeAttr`, exposed to op accessors as
// `llvm::Optional<llvm::ArrayRef<int64_t>>` via `ShapeAttr::getValue()`.
// NOTE(review): presumably the Optional is empty for an unranked shape —
// confirm against ShapeAttr.
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";

  // Create a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}

// Array attribute whose elements all satisfy TF_ShapeAttr.
def TF_ShapeAttrArray :
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;

//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//

// Any tensor element type defined in the TensorFlow dialect (any type for
// which `isa<mlir::TF::TensorFlowType>()` holds).
def TF_TFDialectType :
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Type constraint for one specific TensorFlow dialect type,
// `mlir::TF::<name>Type`. It is also a BuildableType, so ODS can construct the
// type itself with `getType<mlir::TF::<name>Type>()`.
class TF_TensorFlowType <string name, string description> :
    Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
         "TensorFlow " # description # " type">,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Reference types

// Float reference types
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;

//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)
//
// Each scalar constraint below accepts the builtin type and its `TF_*Ref`
// counterpart. Signed TF integers use the builtin I8..I64 constraints, while
// unsigned ones use UI<N>.

def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
                           "32/64-bit signed integer">;

def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
                        "unsigned integer">;

// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
                        "signed integer">;

// Any integer type (signed or unsigned; TF_Bool is not included)
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor types
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;

//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)

def TF_Qint8 : AnyTypeOf<
  [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
  "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
  [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
  "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
  [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
  "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
  [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
  "8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
  [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
  "16-bit quantized unsigned integer">;

// Any quantized type
def TF_Quantized : AnyTypeOf<
  [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;

//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)

def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
  "floating-point">;

// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
// The bit width in the name is the total size: complex<f32> is "64-bit" and
// complex<f64> is "128-bit".
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
  "64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
  "128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;

//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)

def TF_Str : AnyTypeOf<
  [TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant : AnyTypeOf<
  [TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource : AnyTypeOf<
  [TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;

//===----------------------------------------------------------------------===//
// Multi-category type constraints

def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;

def TF_Number : AnyTypeOf<
  [TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

// Non-quantized numbers or strings. Of the unsigned integers only TF_Uint8 is
// accepted (not TF_UInt). No explicit summary is given, so presumably
// AnyTypeOf derives one from the member constraints.
def TF_NumberNotQuantizedOrStr :
  AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;

//===----------------------------------------------------------------------===//
// Tensor and tensor element types

// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
//
// Quantized, string, variant, and resource types are not listed explicitly;
// presumably they are covered by TF_TFDialectType, which accepts any
// `mlir::TF::TensorFlowType`.
def TF_ElementType : Type<Or<[TF_Float.predicate,
                              TF_Complex.predicate,
                              TF_Int.predicate,
                              TF_Bool.predicate,
                              TF_TFDialectType.predicate]>,
                          "tf.dtype">;

// Any TensorFlow tensor type
def TF_Tensor : TensorOf<[TF_ElementType]>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions (continued)
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Tensorflow devices metadata

// Tensorflow GPU device metadata.
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
  // GPU device compute capability: major:minor.
  StructFieldAttr<"cc_major", I32Attr>,
  StructFieldAttr<"cc_minor", I32Attr>
]>;

//===----------------------------------------------------------------------===//
// String attribute constraints

// A string attribute whose value is one of the values in `cases`.
//
// The predicate is built by folding over `cases` into
//   value == "a" || value == "b" || ...
// and the summary reads "string attribute whose value is a, or b, or c".
// `cases` must be non-empty: both folds start from `!head(cases)`.
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
  CPred<!foldl(
      "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
      !foreach(case, !tail(cases),
               "$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
      prev, cur, prev # " || " # cur)>,
  "string attribute whose value is " #
    !foldl(/*init*/!head(cases), /*list*/!tail(cases),
           prev, cur, prev # ", or " # cur)>;

// TODO: Use EnumAttr to define the common attribute cases

// String attribute restricted to "NHWC" or "NCHW". This is a hand-written
// equivalent of TF_AnyStrAttrOf<["NHWC", "NCHW"]> with a custom summary.
def TF_ConvnetDataFormatAttr : StringBasedAttr<
  CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
        "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
  "'NHWC' or 'NCHW' convnet data format">;

//===----------------------------------------------------------------------===//
// Type attributes

// A derived attribute that returns the size of the `idx`-th ODS-declared
// variadic operand. The accessor returns `size_t`; when materialized as an
// attribute it becomes a 64-bit integer attribute (`getI64IntegerAttr`).
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSOperands(" # idx # ");\n"
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// A derived attribute that returns the element type of the `idx`-th
// ODS-declared operand. If the `idx`-th operand is a variadic operand, then
// this attribute just returns the element type of its first tensor, which is
// only meaningful when the variadic operand has at least one tensor and the
// tensors all have the same element type.
// NOTE(review): `*...begin()` is dereferenced unconditionally, so an empty
// variadic pack is undefined behavior, not merely a meaningless result.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;

// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// operand. This returns a list of element types so it is used for variadic
// operands that can have different element types.
//
// The accessor returns an `mlir::OperandElementTypeRange` over the pack. When
// materialized as an attribute it becomes an ArrayAttr with one TypeAttr per
// value.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
  "mlir::OperandElementTypeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::OperandElementTypeIterator(values.begin()), "
  "mlir::OperandElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
    [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic operand.
// This returns a list of shapes so it is used for variadic operands that
// can have different shapes. When materialized as an attribute it becomes an
// ArrayAttr with one `mlir::TF::ShapeAttr` per value.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
  "::mlir::TF::OperandShapeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::TF::OperandShapeIterator(values.begin()), "
  "mlir::TF::OperandShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
    [&](){
      llvm::SmallVector<Attribute, 4> ret;
      for (auto shape : $_self)
        ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
      return ret;
    }())
  }]
>;

// A derived attribute that returns the size of the `idx`-th ODS-declared
// variadic result. The accessor returns `size_t`; when materialized as an
// attribute it becomes a 64-bit integer attribute (`getI64IntegerAttr`).
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSResults(" # idx # ");\n"
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// A derived attribute that returns the element type of the `idx`-th
// ODS-declared result. If the `idx`-th result is a variadic result, then this
// attribute just returns the element type of its first tensor, which is only
// meaningful when the variadic result has at least one tensor and the tensors
// all have the same element type.
// NOTE(review): `*...begin()` is dereferenced unconditionally, so an empty
// variadic result pack is undefined behavior.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;

// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// result. This returns a list of element types so it is used for variadic
// results that can have different element types. When materialized as an
// attribute it becomes an ArrayAttr with one TypeAttr per value.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
  "mlir::ResultElementTypeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::ResultElementTypeIterator(values.begin()), "
  "mlir::ResultElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
    [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic result.
// This returns a list of shapes so it is used for variadic results that
// can have different shapes. When materialized as an attribute it becomes an
// ArrayAttr with one `mlir::TF::ShapeAttr` per value.
// NOTE(review): the return type here is written `mlir::TF::ResultShapeRange`,
// while the operand variant uses the fully qualified
// `::mlir::TF::OperandShapeRange`.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
  "mlir::TF::ResultShapeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::TF::ResultShapeIterator(values.begin()), "
  "mlir::TF::ResultShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
    [&](){
      llvm::SmallVector<Attribute, 4> ret;
      for (auto shape : $_self)
        ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
      return ret;
    }())
  }]
>;

// A derived attribute that returns the shape of the first result type.
// It is materialized as an `mlir::TF::ShapeAttr`.
// NOTE(review): the first result type is dereferenced and `cast` to ShapedType
// unconditionally, so this is only valid for ops that have at least one
// result and whose first result is shaped.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
  [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;

// TypeAttr constraint whose stored type must be an `IntegerType`; the
// generated accessor returns it as a generic `Type`.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}

//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//

// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
// The result type comes from `OpTrait::util::getBroadcastedType`.
// NOTE(review): if that returns a null type, an error is emitted but `build`
// is still called with the null `resultType`.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  auto resultType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!resultType)
    mlir::emitError($_state.location, "non-broadcastable operands");
  return build($_builder, $_state, resultType, x, y);
}]>];
}

// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool element type.
// If either operand is an UnrankedTensorType, the result is an unranked i1
// tensor. Otherwise it is a ranked i1 tensor whose shape is the broadcast of
// the operand shapes.
// NOTE(review): operand types are `cast<ShapedType>`, which assumes shaped
// operands. On a broadcast failure an error is emitted, but construction
// continues with whatever `resultShape` holds.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  Type resultType;
  if (x.getType().isa<UnrankedTensorType>() ||
      y.getType().isa<UnrankedTensorType>()) {
    resultType = UnrankedTensorType::get($_builder.getI1Type());
  } else {
    SmallVector<int64_t, 4> resultShape;
    if (!OpTrait::util::getBroadcastedShape(
            x.getType().cast<ShapedType>().getShape(),
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
    }

    resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}

#endif // TF_OP_BASE