1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H
17 #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H
18
#include <climits>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
#include "mlir/Rewrite/FrozenRewritePatternSet.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
38
39 namespace mlir {
40 namespace tosa {
41
42 // Create a TOSA rescale op from TFLite scaling, zero points and rounding mode
43 Value buildRescale(PatternRewriter& rewriter, Operation* op,
44 ShapedType output_type, Value input_val, double scale,
45 int64_t input_zp, int64_t output_zp, bool double_round,
46 bool scale32);
47
48 // Creates TOSA rescale op with int32 output
49 Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
50 Value input_val, double input_scale,
51 int64_t input_zp);
52
53 // Creates TOSA rescale op with int32 input
54 Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op,
55 ShapedType output_type, Value input_val,
56 double output_scale, int64_t output_zp);
57
58 // Creates a TOSA rescale op based on conv2d parameters.
59 Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op,
60 Value conv_val, ShapedType input_type,
61 ShapedType weight_type, ShapedType output_type);
62
63 // Create a 8-bit TOSA TABLE constant tensor
64 Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op,
65 double input_scale, int32_t input_zp,
66 double output_scale, int32_t output_zp,
67 std::function<double(double)> func);
68
69 // Create a 16-bit TOSA TABLE constant tensor
70 Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op,
71 std::function<double(double)> func, double min,
72 double max);
73
74 // Create a 32-bit TOSA TABLE constant tensor
75 // Output is restricted to [-1.0, 1.0] as s0.31 format
76 void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op,
77 double input_scale, int32_t input_zp,
78 std::function<double(double)> func,
79 Value& first_const, Value& second_const,
80 Value& third_const, Value& fourth_const);
81
82 // Create a 32-bit float constant operator from a float
83 Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
84 float val);
85
86 // Create a 32-bit integer constant operator from an int
87 Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op,
88 int32_t val);
89
90 // Create a vector from a 32-bit value tensor. Returns vector size on success
91 // or -1 on error.
92 LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec);
93
94 // Calculates the TOSA padding values based on TF operators padded with
95 // SAME/VALID.
96 bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad,
97 tensorflow::TensorFormat data_format_tf,
98 uint32_t first_filter_spatial_dim,
99 ShapedType input_type, ShapedType filter_type,
100 ArrayAttr strides, ArrayAttr dilations,
101 PatternRewriter& rewriter,
102 ArrayAttr& explicit_pad);
103
104 // Calculates the TOSA padding values for explicit-padded TF operators.
105 ArrayAttr getPaddingValuesFromExplicitPadAttr(
106 ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf,
107 PatternRewriter& rewriter);
108
109 // Calculates the TOSA padding values for transposeConv2d
110 bool getTransposeConv2dPaddingValues(
111 tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
112 uint32_t first_filter_spatial_dim, ShapedType input_type,
113 ShapedType filter_type, ShapedType output_type, ArrayAttr strides,
114 PatternRewriter& rewriter, ArrayAttr& explicit_pad);
115
116 // Templated function to create a constant op for given type and shape.
117 // T: storage C type.
118 // Default template creates a constant tensor in T.
119 // To create INT48 TOSA constant, need to pass in llvm::APInt instead.
120 template <typename T>
121 llvm::Optional<Value> getConstTensor(PatternRewriter& rewriter, Operation* op,
122 ArrayRef<T> vec, ArrayRef<int64_t> shape);
123
124 // Check if scale32 mode is used for given output_element_type
125 bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
126
127 // Applies a set of patterns greedily to the specified function, then applies
128 // a cleanup to guarantee the function contract and constants are valid. This
129 // means patterns can performed shape inference while not altering immutable
130 // types.
131 LogicalResult ApplyPatternsWithShapeResolution(
132 func::FuncOp func, const FrozenRewritePatternSet& patterns);
133
134 // Creates a TOSA operation and performs shape inference on the individual
135 // op. This allows shape inference during the TFLite to TOSA lowering.
136 template <typename TosaOp, typename... Args>
CreateOpAndInfer(PatternRewriter & rewriter,Location loc,Type result_ty,Args &&...args)137 TosaOp CreateOpAndInfer(PatternRewriter& rewriter, Location loc, Type result_ty,
138 Args&&... args) {
139 auto op = rewriter.create<TosaOp>(loc, result_ty, args...);
140
141 InferShapedTypeOpInterface shapeInterface =
142 dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
143 if (!shapeInterface) return op;
144
145 SmallVector<ShapedTypeComponents> returnedShapes;
146 if (shapeInterface
147 .inferReturnTypeComponents(op.getContext(), op.getLoc(),
148 op->getOperands(), op->getAttrDictionary(),
149 op->getRegions(), returnedShapes)
150 .failed())
151 return op;
152
153 // We need to use the element type of the existing result type to generate
154 // the new result shaped type. This is because rescale can include a cast to
155 // different bit-width types and does not have a TypeAttr to define the
156 // target type.
157 auto result = op->getResult(0);
158 auto predictedShape = returnedShapes[0];
159 auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
160
161 // Compute the knowledge based on the inferred type.
162 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
163 inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
164 inferredKnowledge.hasRank = predictedShape.hasRank();
165 if (predictedShape.hasRank()) {
166 for (auto dim : predictedShape.getDims()) {
167 inferredKnowledge.sizes.push_back(dim);
168 }
169 }
170
171 // Compute the new type based on the joined version.
172 auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
173 auto new_ty = newKnowledge.getType();
174 result.setType(new_ty);
175 return op;
176 }
177
178 template <typename TosaOp, typename... Args>
CreateReplaceOpAndInfer(PatternRewriter & rewriter,Operation * op,Type result_ty,Args &&...args)179 void CreateReplaceOpAndInfer(PatternRewriter& rewriter, Operation* op,
180 Type result_ty, Args&&... args) {
181 auto result =
182 CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), result_ty, args...);
183 rewriter.replaceOp(op, result->getResults());
184 }
185
186 } // namespace tosa
187 } // namespace mlir
188
189 #endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H
190