• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// Operation definitions for the TFR (TensorFlow Composition) dialect.
17
18#ifndef DIALECT_TFR_OPS_
19#define DIALECT_TFR_OPS_
20
21include "mlir/Dialect/Shape/IR/ShapeBase.td"
22include "mlir/IR/OpBase.td"
23include "mlir/IR/FunctionInterfaces.td"
24include "mlir/Dialect/Quant/QuantOpsBase.td"
25include "mlir/IR/SymbolInterfaces.td"
26include "mlir/Interfaces/CallInterfaces.td"
27include "mlir/Interfaces/ControlFlowInterfaces.td"
28include "mlir/Interfaces/InferTypeOpInterface.td"
29include "mlir/Interfaces/SideEffectInterfaces.td"
30include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
31
32//===----------------------------------------------------------------------===//
33// Dialect
34//===----------------------------------------------------------------------===//
35
// The TFR dialect. Generated C++ lives in the ::mlir::TFR namespace.
def TFR_Dialect : Dialect {
  let name = "tfr";
  let cppNamespace = "::mlir::TFR";

  // Emit raw (unprefixed) accessor names such as `callee()`, `args()` and
  // `arg()`; the extraClassDeclaration snippets of the ops in this file call
  // accessors by those names.
  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;

  let description = [{
    The TensorFlow Composition dialect.
  }];
}
47
48//===----------------------------------------------------------------------===//
49// Type classes
50//===----------------------------------------------------------------------===//
51
52// tensor argument types
53class TFR_Type<string name> : DialectType<TFR_Dialect,
54    CPred<"$_self.isa<mlir::TFR::" # name # "Type>()">,
55    "TFR " # name #" type">,
56    BuildableType<"$_builder.getType<mlir::TFR::" # name # "Type>()">;
def TFR_TensorType     : TFR_Type<"TFRTensor">;
def TFR_TensorListType : TFR_Type<"TFRTensorList">;

// Accepts either a tfr.tensor or a tfr.tensor_list.
def TFR_AllTensorTypes : Type<
    Or<[TFR_TensorType.predicate, TFR_TensorListType.predicate]>,
    "all tensor related types">;
62
// Attribute argument types.
def TFR_AttrType : TFR_Type<"TFRAttr">;

// A scalar attribute is any TF element type.
def TFR_AttrScalarType : TypeAlias<TF_ElementType, "scalar attribute">;

// A vector whose elements are TF element types or tfr.attr.
def TFR_AttrVectorType : VectorOf<[TF_ElementType, TFR_AttrType]>;

// Union of tfr.attr, index, scalar attributes and attribute vectors.
def TFR_AllAttrTypes : Type<
    Or<[TFR_AttrType.predicate,
        Index.predicate,
        TFR_AttrScalarType.predicate,
        TFR_AttrVectorType.predicate]>,
    "all attribute related types">;
72
// Operand types accepted by tfr.call: any tensor-related or attribute-related
// type.
def TFR_allowedArgType : Type<
    Or<[TFR_AllTensorTypes.predicate, TFR_AllAttrTypes.predicate]>,
    "allowed tfr.call operand types">;
77
// Attribute kinds accepted as the `value` of tfr.constant. The value is kept
// as a plain `Attribute` both in storage and when returned, and a constant
// built from an attribute uses that attribute unchanged.
def TFR_allowedConstValues : Attr<
    Or<[FlatSymbolRefAttr.predicate,
        TypeAttr.predicate,
        StrAttr.predicate,
        ArrayAttr.predicate]>,
    "allowed tfr.constant value"> {
  let returnType = "Attribute";
  let storageType = "Attribute";
  let constBuilderCall = "$0";
  let convertFromStorage = "$_self";
}
88
// Result types accepted by tfr.call: tensor-related types only.
def TFR_allowedResultType
    : TypeAlias<TFR_AllTensorTypes, "allowed tfr.call result types">;
92
// Built-in tensor types and tfr.tensor can be cast to each other. This
// constraint accepts either side of such a cast, including tensors whose
// element type is a quantized type.
def TFR_singleTensorType : Type<
    Or<[TFR_TensorType.predicate,
        TF_Tensor.predicate,
        TensorOf<[quant_QuantizedType]>.predicate]>,
    "single tensor or tfr.tensor type">;
98
// Element types accepted by tfr.build_list: tfr.tensor, TF element types, or
// tfr.attr.
// NOTE(review): the summary string below does not mention tfr.attr even
// though the predicate accepts it. Consider updating it; note that verifier
// diagnostics (and any tests matching them) would change.
def TFR_allowedBuiltListType : Type<
    Or<[TFR_TensorType.predicate,
        TF_ElementType.predicate,
        TFR_AttrType.predicate]>,
    "single tfr.tensor or tensor element type">;
104
// Result types produced by tfr.build_list: a tensor list or a tfr.attr.
def TFR_allowedListResultType : Type<
    Or<[TFR_TensorListType.predicate, TFR_AttrType.predicate]>,
    "tfr.tensor_list or tfr.attr type">;
109
110//===----------------------------------------------------------------------===//
111// Op classes
112//===----------------------------------------------------------------------===//
113
// Base class for every operation of the TFR dialect; it only binds the op to
// TFR_Dialect.
class TFR_Op<string mnemonic, list<Trait> traits>
    : Op<TFR_Dialect, mnemonic, traits>;
116
def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> {
  let summary = "Direct call to a function referenced by symbol";

  let description = [{
    The `call` operation represents a direct call to a function that is within
    the same symbol scope as the callee. The operands and result types of the
    call must match the specified function type. The callee is encoded as a
    symbol reference attribute named "callee".

    Example:

    ```mlir
    %2 = tfr.call @my_add(%0, %1) : (!tfr.tensor, f32) -> !tfr.tensor_list
    ```

    Note that the operands of the `call` operation can only be of tfr.tensor,
    tfr.tensor_list, tfr.attr, index, TF element types (e.g. MLIR float and
    integer types), or vectors of TF element types / tfr.attr. The results of
    the `call` operation can only be of tfr.tensor and tfr.tensor_list types.
  }];

  let arguments = (ins
    FlatSymbolRefAttr:$callee,
    Variadic<TFR_allowedArgType>:$args);

  let results = (outs
    Variadic<TFR_allowedResultType>:$outs);

  let extraClassDeclaration = [{
    // Return the name of the called symbol.
    StringRef getCallee() { return callee(); }

    // Get the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Return the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return calleeAttr(); }
  }];

  let assemblyFormat = [{
    $callee `(` $args `)` attr-dict `:` functional-type($args, results)
  }];
}
156
def TFR_CastOp : TFR_Op<"cast", [NoSideEffect]> {
  let summary = "Casts between a built-in tensor type and tfr.tensor";

  let description = [{
    The `cast` operation converts the operand with built-in tensor type to
    tfr.tensor type, or vice versa.

    The op has no custom assembly format, so it is written in generic form.

    Example:

    ```mlir
    %1 = "tfr.cast"(%0) : (tensor<f32>) -> !tfr.tensor
    %3 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<f32>
    ```
  }];

  let arguments = (ins TFR_singleTensorType:$arg);

  let results = (outs TFR_singleTensorType:$out);

  let extraClassDeclaration = [{
    // Return element type of the input tensor type. Only available when the
    // input is a MLIR built-in tensor type; returns a null Attribute otherwise.
    Attribute getInputElementType() {
      if (auto ty = arg().getType().dyn_cast<TensorType>()) {
        return TypeAttr::get(ty.getElementType());
      }
      return {};
    }
  }];

  let hasCanonicalizer = 1;
}
187
def TFR_GetShapeOp : TFR_Op<"get_shape", [NoSideEffect]> {
  let summary = "Returns the shape of a tfr.tensor";

  let description = [{
    The `get_shape` operation gets the shape of a tfr.tensor and returns
    !shape.shape type.

    Example (generic and custom form):

    ```mlir
    %1 = "tfr.get_shape"(%0) : (!tfr.tensor) -> !shape.shape
    %1 = tfr.get_shape %0 -> !shape.shape
    ```
  }];

  let arguments = (ins TFR_TensorType:$arg);

  let results = (outs Shape_ShapeType:$out);

  let assemblyFormat = "$arg attr-dict `->` type($out)";

  let hasCanonicalizer = 1;
}
209
def TFR_GetElementTypeOp : TFR_Op<"get_element_type", [NoSideEffect]> {
  let summary = "Returns the element type of a tfr.tensor as a tfr.attr";

  let description = [{
    The `get_element_type` operation gets the element type of a tfr.tensor and
    returns it as a !tfr.attr.

    Example (generic and custom form):

    ```mlir
    %1 = "tfr.get_element_type"(%0) : (!tfr.tensor) -> !tfr.attr
    %1 = tfr.get_element_type %0 -> !tfr.attr
    ```
  }];

  let arguments = (ins TFR_TensorType:$arg);

  let results = (outs TFR_AttrType:$out);

  let assemblyFormat = "$arg attr-dict `->` type($out)";
}
229
def TFR_EqualOp : TFR_Op<"equal", [NoSideEffect, SameTypeOperands]> {
  let summary = "Compares two tfr.attr values for equality";

  let description = [{
    The `equal` operation compares the values of the tfr.attr type arguments.
    The operation returns an i1 boolean indicating if the two values are the
    same.

    Example (custom and generic form):

    ```mlir
    %x = tfr.equal %lhs, %rhs -> i1
    %x = "tfr.equal"(%lhs, %rhs) : (!tfr.attr, !tfr.attr) -> i1
    ```
  }];

  let arguments = (ins
      TFR_AttrType:$lhs,
      TFR_AttrType:$rhs
  );
  let results = (outs BoolLike:$result);

  // Declares a C++ fold hook; its implementation lives outside this file.
  let hasFolder = 1;

  let assemblyFormat = "$lhs `,` $rhs attr-dict `->` type($result)";
}
253
def TFR_ConstOp : TFR_Op<"constant", [ConstantLike, NoSideEffect]> {
  let summary = "Materializes a TF op attribute as a tfr.attr value";

  let description = [{
    The `constant` operation stores a TF op's attribute (a symbol reference,
    type, string, or array attribute) as a !tfr.attr value, which doesn't
    support arithmetic operations.

    Example:

    ```mlir
    %1 = "tfr.constant"() {value = i32} : () -> !tfr.attr
    %2 = "tfr.constant"() {value = [i32, f32]} : () -> !tfr.attr
    %3 = tfr.constant [i32, f32] -> !tfr.attr
    %4 = tfr.constant f32 -> !tfr.attr
    ```
  }];

  let arguments = (ins TFR_allowedConstValues:$value);

  let results = (outs TFR_AttrType:$out);

  let hasFolder = 1;

  // Builds a constant from `value`; the result type is always !tfr.attr.
  let builders = [
    OpBuilder<(ins "Attribute":$value),
    [{
      auto* ctx = value.getContext();
      $_state.addAttribute("value", value);
      $_state.addTypes(TFRAttrType::get(ctx));
    }]>
  ];

  let assemblyFormat = [{
    $value attr-dict `->` type($out)
  }];
}
288
def TFR_ConstantTensorOp : TFR_Op<"constant_tensor", [NoSideEffect]> {
  let summary = "Converts an attribute-typed value into a tensor";

  let description = [{
    The `constant_tensor` operation converts an operand of an attribute
    related type (tfr.attr, index, a TF element type scalar, or a vector) to a
    built-in tensor type or tfr.tensor type. If the result is a built-in
    tensor type, the shape shouldn't be changed during the conversion.

    The op has no custom assembly format, so it is written in generic form.

    Example:

    ```mlir
    %1 = "tfr.constant_tensor"(%0) : (f32) -> tensor<f32>
    %3 = "tfr.constant_tensor"(%2) : (vector<1xf32>) -> tensor<1xf32>
    ```
  }];

  let arguments = (ins TFR_AllAttrTypes:$arg);

  let results = (outs TFR_singleTensorType:$out);

  let hasCanonicalizer = 1;

  let hasVerifier = 1;
}
311
def TFR_GetElementOp : TFR_Op<"get_element", [NoSideEffect]> {
  let summary = "Extracts one tfr.tensor from a tfr.tensor_list";

  let description = [{
    The `get_element` operation extracts the tfr.tensor element at position
    `index` from a tfr.tensor_list.

    Example:

    ```mlir
    %2 = tfr.get_element %1[%0] : (!tfr.tensor_list, index) -> !tfr.tensor
    ```
  }];

  let arguments = (ins
    TFR_TensorListType:$tensor_list,
    Index:$index);

  let results = (outs TFR_TensorType:$out);

  let hasCanonicalizer = 1;

  let assemblyFormat = [{
    $tensor_list `[` $index `]` attr-dict `:`
      `(` type($tensor_list) `,` type($index) `)` `->` type($out)
  }];
}
337
def TFR_BuildListOp : TFR_Op<"build_list", [NoSideEffect]> {
  let summary = "Builds a tfr.tensor_list or tfr.attr from a list of values";

  let description = [{
    The `build_list` operation builds a tensor list from a list of tensors, or
    a tfr.attr from a list of scalars.

    The op has no custom assembly format, so it is written in generic form.

    Example:

    ```mlir
    %3 = "tfr.build_list"(%2, %1, %0) :
      (!tfr.tensor, !tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
    %3 = "tfr.build_list"(%2, %1, %0) : (i32, i32, i32) -> !tfr.attr
    ```
  }];

  let arguments = (ins Variadic<TFR_allowedBuiltListType>:$tensors);

  let results = (outs TFR_allowedListResultType:$out);

  let hasCanonicalizer = 1;
}
358
def TFR_GetLengthOp : TFR_Op<"get_length", [NoSideEffect]> {
  let summary = "Returns the number of tensors in a tfr.tensor_list";

  let description = [{
    The `get_length` operation returns the number of tensors for a
    tfr.tensor_list.

    Example (generic and custom form):

    ```mlir
    %2 = "tfr.get_length"(%1) : (!tfr.tensor_list) -> index
    %2 = tfr.get_length %1 -> index
    ```
  }];

  let arguments = (ins TFR_TensorListType:$tensor_list);

  let results = (outs Index:$out);

  let hasCanonicalizer = 1;

  let assemblyFormat = [{
    $tensor_list attr-dict `->` type($out)
  }];
}
382
383//===----------------------------------------------------------------------===//
384// Function related classes
385//===----------------------------------------------------------------------===//
386
def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
                                    DeclareOpInterfaceMethods<CallableOpInterface>,
                                    FunctionOpInterface,
                                    IsolatedFromAbove, Symbol]> {
  let summary = "TFR Function defines a composition of other ops";

  let description = [{
    Defines a function that can be used to decompose a TF function call to
    the invocation of a set of other TF ops.

    Syntax:

    ```
    op ::= `tfr.func` visibility? symbol-ref-id `(` argument-list `)` (`->`
    function-result-list)? function-attributes? region
    ```

    Example:

    ```mlir
    tfr.func @foo(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list<T>,
                  %arg2: i32 {tfr.name="T", tfr.default=1})
        attributes {qux = "quux"} {
      tfr.return
    }
    ```

    Note the arguments are ordered by the following rule:
      tfr.tensor > tfr.tensor_list > tfr.attr/i32/...,
    and only one tfr.tensor_list argument is allowed.
  }];

  let arguments = (ins
    TypeAttrOf<FunctionType>:$function_type,
    StrAttr:$sym_name
  );

  let results = (outs);

  // When the region list is empty, the tfr.func is an external function used
  // to model the element type constraints of the tf op. Otherwise, there is
  // one region containing the composition.
  let regions = (region VariadicRegion<AnyRegion>:$body);

  let skipDefaultBuilders = 1;

  let builders = [
    OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
  ];

  let extraClassDeclaration = [{
    /// Returns the type of this function.
    /// FIXME: We should drive this via the ODS `type` param.
    FunctionType getFunctionType() {
      return getFunctionTypeAttr().getValue().cast<FunctionType>();
    }
    // No additional type constraints are checked here.
    LogicalResult verifyType() { return success(); }

    /// Returns the argument types of this function.
    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }

    /// Returns the result types of this function.
    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }

    // Get the names of all defined attributes, including both derived and
    // non-derived ones: every attribute name set on the op itself, plus, for
    // each argument that carries a StringAttr under kAttrArgumentNameAttr,
    // that string.
    llvm::StringSet<> getDefinedAttributeNames() {
      llvm::StringSet<> all_attrs;
      for (auto& attr : (*this)->getAttrs()) {
        all_attrs.insert(attr.getName().strref());
      }
      for (const auto& operand : llvm::enumerate(getArgumentTypes())) {
        if (auto attr_name = getArgAttrOfType<StringAttr>(
            operand.index(), kAttrArgumentNameAttr)) {
          all_attrs.insert(attr_name.getValue());
        }
      }
      return all_attrs;
    }
  }];

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
472
def TFR_TFRReturnOp : TFR_Op<"return", [HasParent<"TFRFuncOp">, NoSideEffect,
                                        ReturnLike, Terminator]> {
  let summary = "Terminator of a tfr.func body";

  let description = [{
    A terminator operation for regions that appear in the body of `tfr.func`
    functions. The operands to the `tfr.return` are the result values returned
    by an invocation of the `tfr.func`.

    Note that only the tfr.tensor and tfr.tensor_list can be returned.
  }];

  let arguments = (ins Variadic<TFR_allowedResultType>:$operands);

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
487
488//===----------------------------------------------------------------------===//
489// Quantization related operations
490//===----------------------------------------------------------------------===//
491
def TFR_TFRQuantActRangeOp : TFR_Op<"quant_act_range", [NoSideEffect]> {
  let summary = "Computes the fixed range of a fused activation";

  let description = [{
    The `quant_act_range` returns a pair of values, `min` and `max` (both
    tfr.tensor), that indicate the fixed range for the fused activation `act`
    with the quantization defined by the `scale` and zero point `zp`.
    Currently, the allowed activations are `NONE`, `RELU`, `RELU6` and
    `RELU_N1_TO_1`.

    Example:

    ```mlir
    %3, %4 = tfr.quant_act_range(%2, %1, %0) :
        (!tfr.attr, f32, i64) -> (!tfr.tensor, !tfr.tensor)
    ```
  }];

  let arguments = (ins
      TFR_AttrType:$act,
      F32:$scale,
      I64:$zp);

  let results = (outs TFR_TensorType:$min, TFR_TensorType:$max);

  let assemblyFormat = [{
    `(` $act `,` $scale `,` $zp `)` attr-dict `:` functional-type(operands, results)
  }];
}
518
def TFR_TFRQuantRescaleOp : TFR_Op<"quant_rescale", [NoSideEffect]> {
  let summary = "Rescales an integer tensor by a floating-point factor";

  let description = [{
    The `quant_rescale` rescales the elements of the integer tensor by the
    floating-point rescale factor. This op needs to be legalized to the
    preferred operations of the backends.

    Example:

    ```mlir
    %3 = tfr.quant_rescale(%2, %1, %0) :
        (!tfr.tensor, !tfr.tensor, i64) -> !tfr.tensor
    ```
  }];

  let arguments = (ins
      TFR_TensorType:$input,
      TFR_TensorType:$scale,
      I64:$zp);

  let results = (outs TFR_TensorType:$output);

  let assemblyFormat = [{
    `(` $input `,` $scale `,` $zp `)` attr-dict `:` functional-type(operands, results)
  }];

  let hasCanonicalizer = 1;
}
546
def TFR_TFRQuantRawDataOp : TFR_Op<"quant_raw_data", [
    NoSideEffect,
    SameOperandsAndResultType]> {
  let summary = "Strips quantization parameters from tensor(s)";

  let description = [{
    The `quant_raw_data` removes the quantization parameters from the input
    tensor(s).

    Example:

    ```mlir
    %3 = tfr.quant_raw_data(%0) : (!tfr.tensor) -> !tfr.tensor
    ```
  }];

  let arguments = (ins TFR_AllTensorTypes:$input);

  let results = (outs TFR_AllTensorTypes:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` functional-type($input, results)
  }];

  let hasCanonicalizer = 1;
}
571
def TFR_TFRQuantQParamsOp : TFR_Op<"quant_qparam", [NoSideEffect]> {
  let summary = "Returns the quantization parameters of a tensor";

  let description = [{
    The `quant_qparam` returns the quantization parameters of the input
    tensor: the `scale` and the zero point `zp`, both as tfr.tensor values.

    Example:

    ```mlir
    %3, %4 = tfr.quant_qparam(%0) : (!tfr.tensor) -> (!tfr.tensor, !tfr.tensor)
    ```
  }];

  let arguments = (ins TFR_TensorType:$input);

  let results = (outs TFR_TensorType:$scale, TFR_TensorType:$zp);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` functional-type($input, results)
  }];

  let hasCanonicalizer = 1;
}
594
595
def TFR_TFRQuantScaleFactorOp : TFR_Op<"quant_scale_factor", [NoSideEffect]> {
  let summary = "Computes the effective rescale factor";

  let description = [{
    The `quant_scale_factor` computes the effective scale factor according to
    the output scale `out_scale` and the input scales `in_scales`.

    Example:

    ```mlir
    %2 = tfr.quant_scale_factor(%1, %0) : (f32, !tfr.tensor_list) -> !tfr.tensor
    ```
  }];

  let arguments = (ins
      F32:$out_scale,
      TFR_TensorListType:$in_scales);

  let results = (outs TFR_TensorType:$scale_factor);

  let assemblyFormat = [{
    `(` $out_scale `,` $in_scales `)` attr-dict `:` functional-type(operands, results)
  }];

  let hasCanonicalizer = 1;
}
620
621#endif // DIALECT_TFR_OPS_
622