/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation definition file for the TFR (TensorFlow Composition)
// dialect.

#ifndef DIALECT_TFR_OPS_
#define DIALECT_TFR_OPS_

include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/Dialect/Quant/QuantOpsBase.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"

//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//

def TFR_Dialect : Dialect {
  let name = "tfr";

  let description = [{
    The TensorFlow Composition dialect.
  }];

  let cppNamespace = "::mlir::TFR";

  // Generated accessors use the raw ODS names (e.g. `callee()`, `args()`)
  // instead of `get`-prefixed names.
  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
}

//===----------------------------------------------------------------------===//
// Type classes
//===----------------------------------------------------------------------===//

// Base class for the TFR dialect types. `TFR_Type<"Foo">` matches values whose
// type is the C++ class `mlir::TFR::FooType` and can be built from a builder
// without extra parameters.
class TFR_Type<string name> : DialectType<TFR_Dialect,
    CPred<"$_self.isa<mlir::TFR::" # name # "Type>()">,
    "TFR " # name #" type">,
    BuildableType<"$_builder.getType<mlir::TFR::" # name # "Type>()">;

// tensor argument types
def TFR_TensorType : TFR_Type<"TFRTensor">;
def TFR_TensorListType : TFR_Type<"TFRTensorList">;
def TFR_AllTensorTypes : Type<Or<[
    TFR_TensorType.predicate,
    TFR_TensorListType.predicate]>, "all tensor related types">;

// attribute argument types: an opaque tfr.attr, an index, a scalar of a TF
// element type, or a vector of TF element types / tfr.attr.
def TFR_AttrType : TFR_Type<"TFRAttr">;
def TFR_AttrScalarType: TypeAlias<TF_ElementType, "scalar attribute">;
def TFR_AttrVectorType : VectorOf<[TF_ElementType, TFR_AttrType]>;
def TFR_AllAttrTypes : Type<Or<[
    TFR_AttrType.predicate,
    Index.predicate,
    TFR_AttrScalarType.predicate,
    TFR_AttrVectorType.predicate]>, "all attribute related types">;

// all allowed argument types (of tfr.call): any tensor or attribute type.
def TFR_allowedArgType : Type<Or<[
    TFR_AllTensorTypes.predicate,
    TFR_AllAttrTypes.predicate]>, "allowed tfr.call operand types">;

// Attribute kinds accepted as the `value` of tfr.constant. Both the storage
// and the return type are the generic `Attribute`, so the value is handed
// through without conversion.
def TFR_allowedConstValues : Attr<Or<[
    FlatSymbolRefAttr.predicate,
    TypeAttr.predicate,
    StrAttr.predicate,
    ArrayAttr.predicate]>, "allowed tfr.constant value"> {
  let storageType = "Attribute";
  let returnType = "Attribute";
  let convertFromStorage = "$_self";
  let constBuilderCall = "$0";
}

// all allowed result types
def TFR_allowedResultType : TypeAlias<TFR_AllTensorTypes,
    "allowed tfr.call result types">;

// Standard (built-in) tensor types and the tfr.tensor type can be cast to each
// other.
// Accepts a tfr.tensor, a TF tensor, or a built-in tensor of quantized
// elements.
def TFR_singleTensorType : Type<Or<[
    TFR_TensorType.predicate,
    TF_Tensor.predicate,
    TensorOf<[quant_QuantizedType]>.predicate]>, "single tensor or tfr.tensor type">;

// all allowed build list input types
def TFR_allowedBuiltListType : Type<Or<[
    TFR_TensorType.predicate,
    TF_ElementType.predicate,
    TFR_AttrType.predicate]>, "single tfr.tensor or tensor element type">;

// all allowed build list result types
def TFR_allowedListResultType : Type<Or<[
    TFR_TensorListType.predicate,
    TFR_AttrType.predicate]>, "tfr.tensor_list or tfr.attr type">;

//===----------------------------------------------------------------------===//
// Op classes
//===----------------------------------------------------------------------===//

// Base class for all the TFR dialect operations.
class TFR_Op<string mnemonic, list<Trait> traits> :
    Op<TFR_Dialect, mnemonic, traits>;

def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> {
  let description = [{
    The `call` operation represents a direct call to a function that is within
    the same symbol scope as the callee. The operands and result types of the
    call must match the specified function type. The callee is encoded as a
    symbol reference attribute named "callee".

    Example:

    ```mlir
    %2 = tfr.call @my_add(%0, %1) : (!tfr.tensor, f32) -> !tfr.tensor_list
    ```

    Note that the operands of the `call` operation can only be of tfr.tensor,
    tfr.tensor_list, tfr.attr, index, and MLIR float and integer (scalar or
    vector) types. The results of the `call` operation can only be of
    tfr.tensor and tfr.tensor_list types.
  }];

  let arguments = (ins
    FlatSymbolRefAttr:$callee,
    Variadic<TFR_allowedArgType>:$args);

  let results = (outs
    Variadic<TFR_allowedResultType>:$outs);

  let extraClassDeclaration = [{
    // Returns the symbol name of the called function.
    StringRef getCallee() { return callee(); }

    // Get the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Return the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return calleeAttr(); }
  }];

  let assemblyFormat = [{
    $callee `(` $args `)` attr-dict `:` functional-type($args, results)
  }];
}

def TFR_CastOp : TFR_Op<"cast", [NoSideEffect]> {
  let description = [{
    The `cast` operation converts the operand with built-in tensor type to
    tfr.tensor type, or vice versa.

    Example:

    ```mlir
    %1 = "tfr.cast"(%0) : (tensor<f32>) -> !tfr.tensor
    %3 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<f32>
    ```
  }];

  let arguments = (ins TFR_singleTensorType:$arg);

  let results = (outs TFR_singleTensorType:$out);

  let extraClassDeclaration = [{
    // Returns the element type of the input wrapped in a TypeAttr when the
    // input is an MLIR built-in tensor type; returns a null Attribute
    // otherwise (e.g. for a tfr.tensor input).
    Attribute getInputElementType() {
      if (auto ty = arg().getType().dyn_cast<TensorType>()) {
        return TypeAttr::get(ty.getElementType());
      }
      return {};
    }
  }];

  let hasCanonicalizer = 1;
}

def TFR_GetShapeOp : TFR_Op<"get_shape", [NoSideEffect]> {
  let description = [{
    The `get_shape` operation gets the shape of a tfr.tensor and returns
    !shape.shape type.

    Example (generic and custom forms):

    ```mlir
    %1 = "tfr.get_shape"(%0) : (!tfr.tensor) -> !shape.shape
    %1 = tfr.get_shape %0 -> !shape.shape
    ```
  }];

  let arguments = (ins TFR_TensorType:$arg);

  let results = (outs Shape_ShapeType:$out);

  let assemblyFormat = "$arg attr-dict `->` type($out)";

  let hasCanonicalizer = 1;
}

def TFR_GetElementTypeOp : TFR_Op<"get_element_type", [NoSideEffect]> {
  let description = [{
    The `get_element_type` operation gets the element type of a tfr.tensor and
    returns !tfr.attr.

    Example (generic and custom forms):

    ```mlir
    %1 = "tfr.get_element_type"(%0) : (!tfr.tensor) -> !tfr.attr
    %1 = tfr.get_element_type %0 -> !tfr.attr
    ```
  }];

  let arguments = (ins TFR_TensorType:$arg);

  let results = (outs TFR_AttrType:$out);

  let assemblyFormat = "$arg attr-dict `->` type($out)";
}

def TFR_EqualOp : TFR_Op<"equal", [NoSideEffect, SameTypeOperands]> {
  let description = [{
    The `equal` operation compares the values of the tfr.attr type arguments.
    The operation returns an i1 boolean indicating if the two values are the
    same.

    Example:

    ```mlir
    %x = tfr.equal %lhs, %rhs -> i1
    %x = "tfr.equal"(%lhs, %rhs) : (!tfr.attr, !tfr.attr) -> i1
    ```
  }];

  let arguments = (ins
    TFR_AttrType:$lhs,
    TFR_AttrType:$rhs
  );
  let results = (outs BoolLike:$result);

  let hasFolder = 1;

  let assemblyFormat = "$lhs `,` $rhs attr-dict `->` type($result)";
}

def TFR_ConstOp : TFR_Op<"constant", [ConstantLike, NoSideEffect]> {
  let description = [{
    The `constant` operation stores a TF op's attribute (a symbol reference,
    type, string or array attribute) as a !tfr.attr value, which doesn't
    support arithmetic operations.

    Example:

    ```mlir
    %1 = "tfr.constant"() {value = i32} : () -> !tfr.attr
    %2 = "tfr.constant"() {value = [i32, f32]} : () -> !tfr.attr
    %3 = tfr.constant [i32, f32] -> !tfr.attr
    %4 = tfr.constant f32 -> !tfr.attr
    ```
  }];

  let arguments = (ins TFR_allowedConstValues:$value);

  let results = (outs TFR_AttrType:$out);

  let hasFolder = 1;

  // Builds a tfr.constant holding `value`; the result type is always
  // !tfr.attr.
  let builders = [
    OpBuilder<(ins "Attribute":$value),
    [{
      auto* ctx = value.getContext();
      $_state.addAttribute("value", value);
      $_state.addTypes(TFRAttrType::get(ctx));
    }]>
  ];

  let assemblyFormat = [{
    $value attr-dict `->` type($out)
  }];
}

def TFR_ConstantTensorOp : TFR_Op<"constant_tensor", [NoSideEffect]> {
  let description = [{
    The `constant_tensor` operation converts an attribute-like operand
    (tfr.attr, index, scalar or vector value) to a built-in tensor type or the
    tfr.tensor type. If the result is a built-in tensor type, the shape
    shouldn't be changed during the conversion.

    Example:

    ```mlir
    %1 = "tfr.constant_tensor"(%0) : (f32) -> tensor<f32>
    %3 = "tfr.constant_tensor"(%2) : (vector<1xf32>) -> tensor<1xf32>
    ```
  }];

  let arguments = (ins TFR_AllAttrTypes:$arg);

  let results = (outs TFR_singleTensorType:$out);

  let hasCanonicalizer = 1;

  let hasVerifier = 1;
}

def TFR_GetElementOp : TFR_Op<"get_element", [NoSideEffect]> {
  let description = [{
    The `get_element` operation extracts one tfr.tensor element from a
    tfr.tensor_list.

    Example:

    ```mlir
    %2 = tfr.get_element %1[%0] : (!tfr.tensor_list, index) -> !tfr.tensor
    ```
  }];

  let arguments = (ins
    TFR_TensorListType:$tensor_list,
    Index:$index);

  let results = (outs TFR_TensorType:$out);

  let hasCanonicalizer = 1;

  let assemblyFormat = [{
    $tensor_list `[` $index `]` attr-dict `:`
    `(` type($tensor_list) `,` type($index) `)` `->` type($out)
  }];
}

def TFR_BuildListOp : TFR_Op<"build_list", [NoSideEffect]> {
  let description = [{
    The `build_list` operation builds a tensor list from a list of tensors, or
    a tfr.attr from a list of scalars or tfr.attr values.

    Example:

    ```mlir
    %3 = "tfr.build_list"(%2, %1, %0) :
      (!tfr.tensor, !tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
    %3 = "tfr.build_list"(%2, %1, %0) : (i32, i32, i32) -> !tfr.attr
    ```
  }];

  let arguments = (ins Variadic<TFR_allowedBuiltListType>:$tensors);

  let results = (outs TFR_allowedListResultType:$out);

  let hasCanonicalizer = 1;
}

def TFR_GetLengthOp : TFR_Op<"get_length", [NoSideEffect]> {
  let description = [{
    The `get_length` operation returns the number of tensors for a
    tfr.tensor_list.

    Example (generic and custom forms):

    ```mlir
    %2 = "tfr.get_length"(%1) : (!tfr.tensor_list) -> index
    %2 = tfr.get_length %1 -> index
    ```
  }];

  let arguments = (ins TFR_TensorListType:$tensor_list);

  let results = (outs Index:$out);

  let hasCanonicalizer = 1;

  let assemblyFormat = [{
    $tensor_list attr-dict `->` type($out)
  }];
}

//===----------------------------------------------------------------------===//
// Function related classes
//===----------------------------------------------------------------------===//

def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
                                    DeclareOpInterfaceMethods<CallableOpInterface>,
                                    FunctionOpInterface,
                                    IsolatedFromAbove, Symbol]> {
  let summary = "TFR Function defines a composition of other ops";

  let description = [{
    Defines a function that can be used to decompose a TF function call to
    the invocation of a set of other TF ops.

    Syntax:

    ```
    op ::= `tfr.func` visibility? symbol-ref-id `(` argument-list `)` (`->`
    function-result-list)? function-attributes? region
    ```

    Example:

    ```mlir
    tfr.func @foo(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list<T>,
                  %arg2: i32 {tfr.name="T", tfr.default=1})
        attributes {qux = "quux"} {
      tfr.return
    }
    ```

    Note the arguments are ordered by the following rule:
      tfr.tensor > tfr.tensor_list > tfr.attr/i32/...,
    and only one tfr.tensor_list argument is allowed.
  }];

  let arguments = (ins
    TypeAttrOf<FunctionType>:$function_type,
    StrAttr:$sym_name
  );

  let results = (outs);

  // When the op has no region, the tfr.func is an external function and is
  // used to model the element type constraints of the TF op. Otherwise, there
  // is one region containing the composition.
  let regions = (region VariadicRegion<AnyRegion>:$body);

  let skipDefaultBuilders = 1;

  let builders = [
    OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
  ];

  let extraClassDeclaration = [{
    /// Returns the type of this function.
    /// FIXME: We should drive this via the ODS `type` param.
    FunctionType getFunctionType() {
      return getFunctionTypeAttr().getValue().cast<FunctionType>();
    }
    /// Performs no additional function-type verification; always succeeds.
    LogicalResult verifyType() { return success(); }

    /// Returns the argument types of this function.
    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }

    /// Returns the result types of this function.
    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }

    // Get the names of all defined attributes, including both derived and
    // non-derived ones.
    llvm::StringSet<> getDefinedAttributeNames() {
      llvm::StringSet<> all_attrs;
      // Names of the attributes attached directly to this op.
      for (auto& attr : (*this)->getAttrs()) {
        all_attrs.insert(attr.getName().strref());
      }
      // Names carried by a `kAttrArgumentNameAttr` string attribute on any of
      // the function arguments.
      for (const auto& operand : llvm::enumerate(getArgumentTypes())) {
        if (auto attr_name = getArgAttrOfType<StringAttr>(
            operand.index(), kAttrArgumentNameAttr)) {
          all_attrs.insert(attr_name.getValue());
        }
      }
      return all_attrs;
    }
  }];

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

def TFR_TFRReturnOp : TFR_Op<"return", [HasParent<"TFRFuncOp">, NoSideEffect,
                                        ReturnLike, Terminator]> {
  let description = [{
    A terminator operation for regions that appear in the body of `tfr.func`
    functions. The operands to the `tfr.return` are the result values returned
    by an invocation of the `tfr.func`.

    Note that only tfr.tensor and tfr.tensor_list values can be returned.
  }];

  let arguments = (ins Variadic<TFR_allowedResultType>:$operands);

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

//===----------------------------------------------------------------------===//
// Quantization related operations
//===----------------------------------------------------------------------===//

def TFR_TFRQuantActRangeOp : TFR_Op<"quant_act_range", [NoSideEffect]> {
  let description = [{
    The `quant_act_range` returns a pair of integers (`min`, `max`) to
    indicate the fixed range for the fused activation `act` with the
    quantization defined by the `scale` and zero point `zp`. Currently, the
    allowed activations are `NONE`, `RELU`, `RELU6` and `RELU_N1_TO_1`.

    Example:

    ```mlir
    %3, %4 = tfr.quant_act_range(%2, %1, %0) :
      (!tfr.attr, f32, i64) -> (!tfr.tensor, !tfr.tensor)
    ```
  }];

  let arguments = (ins
    TFR_AttrType:$act,
    F32:$scale,
    I64:$zp);

  let results = (outs TFR_TensorType:$min, TFR_TensorType:$max);

  let assemblyFormat = [{
    `(` $act `,` $scale `,` $zp `)` attr-dict `:` functional-type(operands, results)
  }];
}

def TFR_TFRQuantRescaleOp : TFR_Op<"quant_rescale", [NoSideEffect]> {
  let description = [{
    The `quant_rescale` rescales the elements of the integer tensor by the
    floating-point rescale factor. This op needs to be legalized to the preferred
    operations of the backends.

    Example:

    ```mlir
    %3 = tfr.quant_rescale(%2, %1, %0) :
      (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor)
    ```
  }];

  let arguments = (ins
    TFR_TensorType:$input,
    TFR_TensorType:$scale,
    I64:$zp);

  let results = (outs TFR_TensorType:$output);

  let assemblyFormat = [{
    `(` $input `,` $scale `,` $zp `)` attr-dict `:` functional-type(operands, results)
  }];

  let hasCanonicalizer = 1;
}

def TFR_TFRQuantRawDataOp : TFR_Op<"quant_raw_data", [
    NoSideEffect,
    SameOperandsAndResultType]> {
  let description = [{
    The `quant_raw_data` removes the quantization parameter from the input
    tensor(s).

    Example:

    ```mlir
    %3 = tfr.quant_raw_data(%0) : (!tfr.tensor) -> (!tfr.tensor)
    ```
  }];

  let arguments = (ins TFR_AllTensorTypes:$input);

  let results = (outs TFR_AllTensorTypes:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` functional-type($input, results)
  }];

  let hasCanonicalizer = 1;
}

def TFR_TFRQuantQParamsOp : TFR_Op<"quant_qparam", [NoSideEffect]> {
  let description = [{
    The `quant_qparam` returns the quantization parameters (`scale` and zero
    point `zp`) of the input tensor.

    Example:

    ```mlir
    %1, %2 = tfr.quant_qparam(%0) : (!tfr.tensor) -> (!tfr.tensor, !tfr.tensor)
    ```
  }];

  let arguments = (ins TFR_TensorType:$input);

  let results = (outs TFR_TensorType:$scale, TFR_TensorType:$zp);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` functional-type($input, results)
  }];

  let hasCanonicalizer = 1;
}


def TFR_TFRQuantScaleFactorOp : TFR_Op<"quant_scale_factor", [NoSideEffect]> {
  let description = [{
    The `quant_scale_factor` computes the effective scale factor according to the
    output scale and input scales.

    Example:

    ```mlir
    %3 = tfr.quant_scale_factor(%1, %0) :
      (f32, !tfr.tensor_list) -> (!tfr.tensor)
    ```
  }];

  let arguments = (ins
    F32:$out_scale,
    TFR_TensorListType:$in_scales);

  let results = (outs TFR_TensorType:$scale_factor);

  let assemblyFormat = [{
    `(` $out_scale `,` $in_scales `)` attr-dict `:` functional-type(operands, results)
  }];

  let hasCanonicalizer = 1;
}

#endif // DIALECT_TFR_OPS_