/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow traits, types, attributes, and
// builders.

#ifndef TF_OP_BASE
#define TF_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"

//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//

// The `tf` dialect. Generated C++ for this dialect is placed in the
// ::mlir::TF namespace (see `cppNamespace`).
def TF_Dialect : Dialect {
  let name = "tf";

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];

  let cppNamespace = "::mlir::TF";
}

//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//
// Each trait below is a NativeOpTrait: the string names the C++ trait class
// (in the TF namespace) that implements it; those classes are defined outside
// this file.

// Specify this trait if the op requires all outputs to have the same type and
// the inputs either have the same type as result or a ref type corresponding to
// the result type.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
  "TF::OperandsSameAsResultsTypeOrRef">;

// Op has the same operand and result element types (or type itself, if scalar)
// after resolving reference types (i.e., after converting reference types to
// their corresponding TensorFlow or standard types).
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultElementTypeResolveRef">;

// Op has the same operand and result types after resolving reference types
// (i.e., after converting reference types to their corresponding TensorFlow or
// standard types).
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultTypeResolveRef">;

// Layout agnostic operations do not depend on the operands' data layout (data
// format); for example, all element-wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// Trait to indicate operations that cannot be duplicated as they might carry
// certain state around within their implementations.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// Trait to indicate an operation cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient-wise binary operation with implicit broadcasting support, for
// example the tf.Sub operation.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient-wise unary operation, for example the tf.Sqrt operation.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;

// Variant of the broadcastable trait that considers TF's subtype behavior.
// Holds when operand `opId` and result `resId` satisfy
// TCOpResIsShapedTypePred and mlir::TF::BroadcastCompatible accepts the pair
// (operand type, result type).
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<opId, resId>,
    CPred<"mlir::TF::BroadcastCompatible("
              "$_op.getOperand(" # opId # ").getType(), "
              "$_op.getResult(" # resId # ").getType())">]>;

// Predicate that holds when mlir::TF::AreCastCompatible accepts the given
// types. Each entry of `values` is a C++ expression yielding an mlir::Type;
// the entries are packed into an initializer list passed as an ArrayRef.
// Fully qualified for consistency with TF_OpIsBroadcastableToRes, so the
// generated code does not depend on the enclosing C++ namespace.
class TF_AllTypesMatchPred<list<string> values> :
    CPred<"mlir::TF::AreCastCompatible(llvm::makeArrayRef({" #
      !interleave(values, ", ") # "}))">;

// Trait requiring that the named operands/results (given by their ODS names)
// have cast-compatible ("dynamically equal") types. Each name `n` is rewritten
// into the C++ expression `$n.getType()` (by substituting "$n" for "$_self" in
// "$_self.getType()") before being handed to TF_AllTypesMatchPred.
// The summary intentionally has no trailing whitespace: it is embedded
// verbatim in verifier diagnostics.
class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<
        "all of {" # !interleave(names, ", ") #
          "} have dynamically equal types",
        TF_AllTypesMatchPred<
          !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

// Holds if the type of the n-th operand is an UnrankedTensorType.
class TF_OperandIsUnrankedPred<int n> :
  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// Holds if the type of the n-th result is an UnrankedTensorType.
class TF_ResultIsUnrankedPred<int n> :
  CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Trait that holds if the n-th operand has unknown rank or has rank m.
// A ranked operand is cast to ShapedType, so the operand is assumed to be a
// shaped type.
class TF_OperandHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[TF_OperandIsUnrankedPred<n>,
        CPred<"$_op.getOperand(" # n #
          ").getType().cast<ShapedType>().getRank() == " # m>]>>;

// Trait that holds if the n-th result has unknown rank or has rank m.
class TF_ResultHasRank<int n, int m> :
  PredOpTrait<"result " # n # " is " # m # "-D",
    Or<[TF_ResultIsUnrankedPred<n>,
        CPred<"$_op.getResult(" # n #
          ").getType().cast<ShapedType>().getRank() == " # m>]>>;

//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//

// Base class for TensorFlow side-effect resources. `resourceKind` names a C++
// class nested in ::mlir::TF::ResourceEffects (declared outside this file).
class TF_ResourceBase<string resourceKind> :
  Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
}

// Resource kinds.
def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;

// Read effects. Only the resources listed here have a read effect defined.
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;

// Write effects.
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;

// Allocation effects.
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;

// Free effects. Only the resources listed here have a free effect defined.
def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;

// Complete side-effect list for TPU embedding ops: a single write effect on
// the TPUEmbedding resource.
def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;

//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

// Base class for TensorFlow dialect ops: `mnemonic` is the op name inside the
// `tf` dialect and `traits` are forwarded unchanged to MLIR's `Op` class.
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TF_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

// Attribute constraint satisfied by values of the C++ class
// mlir::TF::<name>Attr. The summary reads "TensorFlow <description> attribute".
class TF_TensorFlowAttr <string name, string description> :
    Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
         "TensorFlow " # description # " attribute">;

// TensorFlow shape attribute. The accessor returns the shape as an optional
// list of dimension sizes (presumably empty for an unranked shape -- confirm
// against mlir::TF::ShapeAttr::getValue).
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";

  // Create a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}

// Array attribute whose elements are all TF_ShapeAttr.
def TF_ShapeAttrArray :
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;

//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//

// Any tensor element type defined in the TensorFlow dialect, i.e. any type
// for which `isa<mlir::TF::TensorFlowType>()` holds.
def TF_TFDialectType :
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Class for any TensorFlow dialect specific type. `name` selects the C++ class
// mlir::TF::<name>Type, which is used both for the type check and (through
// BuildableType) to construct an instance of the type.
class TF_TensorFlowType <string name, string description> :
    Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
         "TensorFlow " # description # " type">,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Reference types

// Float reference types
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;

//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)

def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
                           "32/64-bit signed integer">;

def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
                        "unsigned integer">;

// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
                        "signed integer">;

// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor types
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;

//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)

def TF_Qint8 : AnyTypeOf<
  [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
  "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
  [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
  "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
  [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
  "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
  [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
  "8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
  [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
  "16-bit quantized unsigned integer">;

// Any quantized type
def TF_Quantized : AnyTypeOf<
  [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;

//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)

def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
  "floating-point">;

// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
  "64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
  "128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;

//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)

def TF_Str : AnyTypeOf<
  [TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant : AnyTypeOf<
  [TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource : AnyTypeOf<
  [TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;

//===----------------------------------------------------------------------===//
// Multi-category type constraints

def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;

// Any numeric type: integer, floating-point, quantized or complex.
def TF_Number : AnyTypeOf<
  [TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

// Non-quantized numbers or strings. Note that, of the unsigned integer types,
// only 8-bit unsigned integers are included.
def TF_NumberNotQuantizedOrStr :
  AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;

//===----------------------------------------------------------------------===//
// Tensor and tensor element types

// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType).
// Built by OR-ing the predicates of the float, complex, integer and bool
// constraints with the "any TensorFlow dialect type" predicate.
def TF_ElementType : Type<Or<[TF_Float.predicate,
                              TF_Complex.predicate,
                              TF_Int.predicate,
                              TF_Bool.predicate,
                              TF_TFDialectType.predicate]>,
                          "tf.dtype">;

// Any TensorFlow tensor type
def TF_Tensor : TensorOf<[TF_ElementType]>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// TensorFlow devices metadata

// TensorFlow GPU device metadata.
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
  // GPU device compute capability: major:minor.
  StructFieldAttr<"cc_major", I32Attr>,
  StructFieldAttr<"cc_minor", I32Attr>
]>;

//===----------------------------------------------------------------------===//
// String attribute constraints

// A string attribute whose value is one of the values in `cases`.
// `cases` must be non-empty: the predicate is seeded with `!head(cases)`, so an
// empty list is rejected when TableGen runs. The predicate ORs one
// string-equality test per case; the summary joins the cases with ", or ",
// e.g. ["A", "B", "C"] yields "string attribute whose value is A, or B, or C".
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
  CPred<!foldl(
      "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
      !tail(cases), acc, choice,
      acc # " || $_self.cast<StringAttr>().getValue() == \"" # choice # "\"")>,
  "string attribute whose value is " # !interleave(cases, ", or ")>;

// TODO: Use EnumAttr to define the common attribute cases

// Convnet data format attribute: the string must be exactly "NHWC" or "NCHW".
def TF_ConvnetDataFormatAttr : StringBasedAttr<
    CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || "
          "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
    "'NHWC' or 'NCHW' convnet data format">;

//===----------------------------------------------------------------------===//
// Type attributes

// Derived attribute holding the number of values in the `operandIdx`-th
// ODS-declared (variadic) operand group; materialized as an I64 attribute.
class TF_DerivedOperandSizeAttr<int operandIdx> : DerivedAttr<
  "size_t",
  "auto odsRange = getODSOperands(" # operandIdx # ");\n"
  "return std::distance(odsRange.begin(), odsRange.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// Derived attribute holding the element type of the `operandIdx`-th
// ODS-declared operand. For a variadic operand only the first value is
// inspected, so the result is only meaningful when the group is non-empty and
// all of its values share one element type (the first value is dereferenced
// unconditionally).
class TF_DerivedOperandTypeAttr<int operandIdx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSOperands(" # operandIdx # ").begin());">;

// Derived attribute holding the element type of every value in the actual
// value pack bound to the `idx`-th ODS-declared variadic operand. Because it
// yields one element type per value, it suits variadic operands whose values
// may have different element types.
class TF_DerivedOperandTypeListAttr<int operandIdx> : DerivedAttr<
  "mlir::OperandElementTypeRange",
  "auto odsValues = getODSOperands(" # operandIdx # ");\n"
  "return {mlir::OperandElementTypeIterator(odsValues.begin()), "
  "mlir::OperandElementTypeIterator(odsValues.end())};",
  // Materialized as an ArrayAttr with one TypeAttr per element type.
  [{
    ArrayAttr::get($_ctx,
    [&]() {
      llvm::SmallVector<Attribute, 4> typeAttrs;
      for (auto elementType : $_self)
        typeAttrs.push_back(TypeAttr::get(elementType));
      return typeAttrs;
    }())
  }]
>;

// Derived attribute holding the shape of every value in the actual value pack
// bound to the `operandIdx`-th ODS-declared variadic operand. It yields one
// shape per value, so it suits variadic operands whose values may have
// different shapes. Materialized as an ArrayAttr of mlir::TF::ShapeAttr.
class TF_DerivedOperandShapeListAttr<int operandIdx> : DerivedAttr<
  "::mlir::TF::OperandShapeRange",
  "auto odsValues = getODSOperands(" # operandIdx # ");\n"
  "return {mlir::TF::OperandShapeIterator(odsValues.begin()), "
  "mlir::TF::OperandShapeIterator(odsValues.end())};",
  [{
    ArrayAttr::get($_ctx,
    [&](){
      llvm::SmallVector<Attribute, 4> shapeAttrs;
      for (auto valueShape : $_self)
        shapeAttrs.push_back(mlir::TF::ShapeAttr::get($_ctx, valueShape));
      return shapeAttrs;
    }())
  }]
>;

// Derived attribute holding the number of values in the `resultIdx`-th
// ODS-declared (variadic) result group; materialized as an I64 attribute.
class TF_DerivedResultSizeAttr<int resultIdx> : DerivedAttr<
  "size_t",
  "auto odsRange = getODSResults(" # resultIdx # ");\n"
  "return std::distance(odsRange.begin(), odsRange.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// Derived attribute holding the element type of the `idx`-th ODS-declared
// result. For a variadic result only the first value is inspected, so the
// result is only meaningful when the group is non-empty and all of its values
// share one element type.
class TF_DerivedResultTypeAttr<int resultIdx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSResults(" # resultIdx # ").begin());">;

// Derived attribute holding the element type of every value in the actual
// value pack bound to the `resultIdx`-th ODS-declared variadic result. It
// yields one element type per value, so it suits variadic results whose values
// may have different element types. Materialized as an ArrayAttr of TypeAttr.
class TF_DerivedResultTypeListAttr<int resultIdx> : DerivedAttr<
  "mlir::ResultElementTypeRange",
  "auto odsValues = getODSResults(" # resultIdx # ");\n"
  "return {mlir::ResultElementTypeIterator(odsValues.begin()), "
  "mlir::ResultElementTypeIterator(odsValues.end())};",
  [{
    ArrayAttr::get($_ctx,
    [&]() {
      llvm::SmallVector<Attribute, 4> typeAttrs;
      for (auto elementType : $_self)
        typeAttrs.push_back(TypeAttr::get(elementType));
      return typeAttrs;
    }())
  }]
>;

// Derived attribute holding the shape of every value in the actual value pack
// bound to the `resultIdx`-th ODS-declared variadic result. It yields one
// shape per value, so it suits variadic results whose values may have
// different shapes. Materialized as an ArrayAttr of mlir::TF::ShapeAttr.
class TF_DerivedResultShapeListAttr<int resultIdx> : DerivedAttr<
  "mlir::TF::ResultShapeRange",
  "auto odsValues = getODSResults(" # resultIdx # ");\n"
  "return {mlir::TF::ResultShapeIterator(odsValues.begin()), "
  "mlir::TF::ResultShapeIterator(odsValues.end())};",
  [{
    ArrayAttr::get($_ctx,
    [&](){
      llvm::SmallVector<Attribute, 4> shapeAttrs;
      for (auto valueShape : $_self)
        shapeAttrs.push_back(mlir::TF::ShapeAttr::get($_ctx, valueShape));
      return shapeAttrs;
    }())
  }]
>;

// Derived attribute holding the type of the op's first result, viewed as a
// ShapedType.
// The op is assumed to have at least one result: the first result type is
// dereferenced unconditionally and cast to ShapedType.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
  [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;

// Type attribute constrained to IntegerType; its accessor returns the value as
// a generic mlir Type.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}

//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//

// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
//
// If the operands are not broadcast compatible, an error is emitted at the
// op location. The op is then still built, with an unranked tensor of the
// first operand's element type as its result. Previously a null Type was used
// here, which left a malformed op that crashes later verification or printing.
class WithBroadcastableBinOpBuilder {
  list<OpBuilderDAG> builders = [
    OpBuilderDAG<(ins "Value":$x, "Value":$y),
    [{
  Type resultType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!resultType) {
    mlir::emitError($_state.location, "non-broadcastable operands");
    // Keep the op well-formed after reporting the error.
    resultType =
        UnrankedTensorType::get(mlir::getElementTypeOrSelf(x.getType()));
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}

// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool (i1) element type.
//
// The result is unranked if either operand is unranked. It is also unranked
// if the static shapes are not broadcastable; in that case an error is emitted
// first. On success it is ranked with the broadcast shape. Operands that are
// not unranked are assumed to be ShapedType (they are cast unconditionally).
class WithBroadcastableCmpOpBuilder {
  list<OpBuilderDAG> builders = [
    OpBuilderDAG<(ins "Value":$x, "Value":$y),
    [{
  Type resultType;
  if (x.getType().isa<UnrankedTensorType>() ||
      y.getType().isa<UnrankedTensorType>()) {
    resultType = UnrankedTensorType::get($_builder.getI1Type());
  } else {
    SmallVector<int64_t, 4> resultShape;
    if (OpTrait::util::getBroadcastedShape(
            x.getType().cast<ShapedType>().getShape(),
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
      resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
    } else {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
      // The computed shape is not meaningful on failure; use an unranked
      // result instead of a ranked type built from it.
      resultType = UnrankedTensorType::get($_builder.getI1Type());
    }
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}

#endif // TF_OP_BASE