1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <string>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/Twine.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
32 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
33 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
34 #include "mlir/IR/Attributes.h" // from @llvm-project
35 #include "mlir/IR/Builders.h" // from @llvm-project
36 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
37 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
38 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
39 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
40 #include "mlir/IR/FunctionImplementation.h" // from @llvm-project
41 #include "mlir/IR/MLIRContext.h" // from @llvm-project
42 #include "mlir/IR/Matchers.h" // from @llvm-project
43 #include "mlir/IR/OpDefinition.h" // from @llvm-project
44 #include "mlir/IR/OpImplementation.h" // from @llvm-project
45 #include "mlir/IR/PatternMatch.h" // from @llvm-project
46 #include "mlir/IR/Types.h" // from @llvm-project
47 #include "mlir/Support/LogicalResult.h" // from @llvm-project
48 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
49 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
50 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
52 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
53
54 namespace mlir {
55
56 namespace TFR {
57
58 //===----------------------------------------------------------------------===//
59 // InlinerInterface
60 //===----------------------------------------------------------------------===//
61
62 namespace {
63 /// This class defines the interface for inlining within the TFR dialect.
64 class TFRInlinerInterface : public DialectInlinerInterface {
65 using DialectInlinerInterface::DialectInlinerInterface;
66
67 public:
68 // Allow all call operations to be inlined.
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const69 bool isLegalToInline(Operation *call, Operation *callable,
70 bool wouldBeCloned) const final {
71 return true;
72 }
73 // Returns true if the given region 'src' can be inlined into the region
74 // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping &) const75 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
76 BlockAndValueMapping &) const final {
77 return true;
78 }
79
80 // Returns true if the given operation 'op', that is registered to this
81 // dialect, can be inlined into the region 'dest' that is attached to an
82 // operation registered to the current dialect.
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping &) const83 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
84 BlockAndValueMapping &) const final {
85 return true;
86 }
87
88 // Handle the given inlined terminator by replacing it with a new operation
89 // as necessary. Required when the region has only one block.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const90 void handleTerminator(Operation *op,
91 ArrayRef<Value> valuesToRepl) const final {
92 auto retValOp = dyn_cast<TFRReturnOp>(op);
93 if (!retValOp) return;
94
95 for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
96 std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
97 }
98 }
99
100 // Attempts to materialize a conversion for a type mismatch between a call
101 // from this dialect, and a callable region. This method should generate an
102 // operation that takes 'input' as the only operand, and produces a single
103 // result of 'resultType'. If a conversion can not be generated, nullptr
104 // should be returned.
materializeCallConversion(OpBuilder & builder,Value input,Type result_type,Location conversion_loc) const105 Operation *materializeCallConversion(OpBuilder &builder, Value input,
106 Type result_type,
107 Location conversion_loc) const final {
108 if (!input.getType().isa<IntegerType>() ||
109 !result_type.isa<IntegerType>()) {
110 return nullptr;
111 }
112 auto input_itype = input.getType().cast<IntegerType>();
113 auto result_itype = result_type.cast<IntegerType>();
114 if (input_itype.getWidth() == result_itype.getWidth()) return nullptr;
115 if (input_itype.getWidth() > result_itype.getWidth()) {
116 return builder.create<TruncateIOp>(conversion_loc, result_type, input);
117 } else {
118 return builder.create<SignExtendIOp>(conversion_loc, result_type, input);
119 }
120 }
121 };
122 } // namespace
123
124 //===----------------------------------------------------------------------===//
125 // TFR Dialect
126 //===----------------------------------------------------------------------===//
127
// Constructs the TFR dialect: registers its types, the ODS-generated
// operations, and the inliner interface. TensorFlow is loaded eagerly since
// TFR canonicalization patterns create TF ops.
TFRDialect::TFRDialect(MLIRContext *context)
    : Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
  // TFR depends on TensorFlow for its canonicalization
  context->getOrLoadDialect<TF::TensorFlowDialect>();

  addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
  // Ops come from the TableGen-generated op list.
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
      >();

  addInterfaces<TFRInlinerInterface>();
}
141
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)142 Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
143 Type type, Location loc) {
144 if (ConstantOp::isBuildableWith(value, type))
145 return builder.create<ConstantOp>(loc, type, value);
146 return nullptr;
147 }
148
classof(Type type)149 bool TFRType::classof(Type type) {
150 return llvm::isa<TFRDialect>(type.getDialect());
151 }
152
153 //===----------------------------------------------------------------------===//
154 // Custom op methods
155 //===----------------------------------------------------------------------===//
156
Verify(ConstantTensorOp op)157 static LogicalResult Verify(ConstantTensorOp op) {
158 auto input_type = op.arg().getType();
159 auto output_type = op.out().getType();
160
161 if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
162 return success();
163 }
164
165 auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
166 if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
167 op.emitError("output type should be static and ranked.");
168 return failure();
169 }
170
171 if (output_tensor_type.getRank() == 0) {
172 bool same_scalar = output_tensor_type.getElementType() == input_type;
173 if (!same_scalar) {
174 op.emitError("input and output should have the same scalar types.");
175 }
176 return success(same_scalar);
177 }
178
179 if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
180 bool same_element_type = output_tensor_type.getElementType() ==
181 input_vector_type.getElementType();
182 bool same_shape =
183 output_tensor_type.getShape() == input_vector_type.getShape();
184 if (!same_element_type || !same_shape) {
185 op.emitError("input and output should have same shape and element type.");
186 }
187 return success(same_element_type && same_shape);
188 }
189
190 op.emitError("input can not be converted to an output tensor.");
191 return failure();
192 }
193
// Verifies the signature contract of a tfr.func:
//  - arguments are ordered tfr.tensor, then tfr.tensor_list, then attributes;
//  - results are ordered tfr.tensor before the tfr.tensor_list result;
//  - at most one tfr.tensor_list argument/result (only enforced when the
//    layout can't be recovered from attributes);
//  - every attribute key referenced by a tensor/tensor_list type must be
//    defined as a function attribute.
static LogicalResult Verify(TFRFuncOp func) {
  // Collect all attribute names used by the tensor and tensor list arguments
  // and returns. Also, collect the names of all the attribute arguments as the
  // defined list. Later on, the used attribute names will be verified to be in
  // the defined list.
  llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;

  // While scanning the arguments, record the start/end indices of each argument
  // type, so the order can be verified as well.
  // TODO(fengliuai): the attribute arguments with default values need to be
  // at the end?
  int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
      last_tensor_list = -1, first_attr = -1;
  for (auto arg : llvm::enumerate(func.getType().getInputs())) {
    Type arg_type = arg.value();

    if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
      if (first_tensor == -1) {
        first_tensor = arg.index();
      }
      last_tensor = arg.index();
      auto used = tensor.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
      if (first_tensor_list == -1) {
        first_tensor_list = arg.index();
      }
      last_tensor_list = arg.index();
      auto used = tensor_list.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    // Any other non-builtin-tensor type is an attribute argument; it must
    // carry a tfr.name arg attribute so it can be referenced by name.
    if (!arg_type.isa<TensorType>()) {
      if (first_attr == -1) {
        first_attr = arg.index();
      }
      auto name =
          func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
      if (!name) {
        func.emitError(
            llvm::Twine(arg.index()) +
            " attribute argument doesn't have a tfr.name attribute.");
        return failure();
      }
      continue;
    }

    func.emitError("Builtin TensorType isn't allowed as the argument.");
    return failure();
  }

  // Collect all the undefined attributes used in the inputs.
  llvm::SmallVector<StringAttr, 4> undefined_attrs;
  for (auto attr : input_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify the argument order: tensors, tensor list, attributes; and also
  // verify there is at most one tensor list argument.
  if (first_attr != -1 &&
      (first_attr < last_tensor_list || first_attr < last_tensor)) {
    func.emitError(
        "tfr.tensor/tfr.tensor_list argument should be before non tensor "
        "arguments.");
    return failure();
  }
  // The order between tensor arguments and tensor list arguments and the number
  // of tensor list arguments are verified only when they couldn't be determined
  // by the attributes.
  if (!undefined_attrs.empty()) {
    if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
      func.emitError(
          "tfr.tensor argument should be before tfr.tensor_list argument.");
      return failure();
    }
    if (first_tensor_list != last_tensor_list) {
      func.emitError("More than one tfr.tensor_list argument isn't allowed.");
      return failure();
    }
  }

  // Verify the result order: tensor, tensor list, and also verify at most one
  // tensor list result.
  int undefined_input_attrs_number = undefined_attrs.size();
  bool seen_tensor_list = false, has_tensor_list_order_error = false,
       has_multiple_tensor_lists_error = false;
  for (auto result_type : func.getType().getResults()) {
    if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
      if (seen_tensor_list) {
        has_tensor_list_order_error = true;
      } else {
        auto used = tensor.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
      if (seen_tensor_list) {
        has_multiple_tensor_lists_error = true;
      } else {
        seen_tensor_list = true;
        auto used = tensor_list.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    func.emitError(
        "None tfr.tensor/tfr.tensor_list results aren't allowed as a "
        "result.");
    return failure();
  }

  // Collect all the undefined attributes used in the outputs.
  for (auto attr : output_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify there are no tensor/tensor list order error and multiple tensor
  // list arguments error.
  // NOTE(review): these result-ordering errors are only reported when the
  // results referenced attributes that are undefined — presumably because
  // defined attributes can disambiguate the output layout; confirm intent.
  if (undefined_input_attrs_number != undefined_attrs.size()) {
    if (has_tensor_list_order_error) {
      func.emitError(
          "tfr.tensor result should be before tfr.tensor_list result.");
      return failure();
    } else if (has_multiple_tensor_lists_error) {
      func.emitError("More than one tfr.tensor_list result isn't allowed.");
      return failure();
    }
  }

  // TODO(fengliuai): We might want to refine this constraint because the
  // tensor element type can be derived.
  if (!undefined_attrs.empty()) {
    llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
    std::transform(undefined_attrs.begin(), undefined_attrs.end(),
                   attr_names.begin(),
                   [](StringAttr attr) { return attr.getValue().str(); });
    func.emitError(llvm::Twine("Undefined attributes are used: ",
                               llvm::join(attr_names, ",")));
    return failure();
  }

  return success();
}
348
ParseFuncOp(OpAsmParser & parser,OperationState * result)349 static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) {
350 auto build_func_type = [](Builder &builder, ArrayRef<Type> arg_types,
351 ArrayRef<Type> results,
352 function_like_impl::VariadicFlag, std::string &) {
353 return builder.getFunctionType(arg_types, results);
354 };
355 return function_like_impl::parseFunctionLikeOp(
356 parser, *result, /*allowVariadic=*/false, build_func_type);
357 }
358
PrintFuncOp(OpAsmPrinter & p,TFRFuncOp op)359 static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) {
360 FunctionType fn_type = op.getType();
361 function_like_impl::printFunctionLikeOp(
362 p, op, fn_type.getInputs(), /*isVariadic=*/false, fn_type.getResults());
363 }
364
365 } // namespace TFR
366 } // namespace mlir
367
368 //===----------------------------------------------------------------------===//
369 // TableGen'd op method definitions
370 //===----------------------------------------------------------------------===//
371
372 #define GET_OP_CLASSES
373 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
374
375 namespace mlir {
376 namespace TFR {
377 namespace {
378 class ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
379 using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
380
381 public:
matchAndRewrite(ConstantTensorOp cst_tensor_op,PatternRewriter & rewriter) const382 LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
383 PatternRewriter &rewriter) const override {
384 Location loc = cst_tensor_op.getLoc();
385 Type out_type = cst_tensor_op.getType();
386 Operation *new_cst = nullptr;
387
388 ArrayAttr array;
389 if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
390 llvm::DenseSet<Type> all_types;
391 for (auto it : array) {
392 all_types.insert(it.getType());
393 }
394 if (all_types.size() != 1) return failure();
395 ShapedType new_out_type = RankedTensorType::get(
396 {static_cast<int64_t>(array.size())}, *all_types.begin());
397 DenseElementsAttr attr =
398 DenseElementsAttr::get(new_out_type, array.getValue());
399 new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
400 if (out_type.isa<TFRTensorType>()) {
401 new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
402 }
403 rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
404 return success();
405 }
406
407 Attribute scalar;
408 if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
409 Type new_out_type = RankedTensorType::get({}, scalar.getType());
410 new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
411 if (out_type.isa<TFRTensorType>()) {
412 new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
413 }
414 rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
415 return success();
416 }
417 return failure();
418 }
419 };
420
421 class RemoveRedundantCast : public OpRewritePattern<CastOp> {
422 using OpRewritePattern<CastOp>::OpRewritePattern;
423
424 public:
matchAndRewrite(CastOp cast_op,PatternRewriter & rewriter) const425 LogicalResult matchAndRewrite(CastOp cast_op,
426 PatternRewriter &rewriter) const override {
427 auto preceding_cast =
428 llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
429 if (!preceding_cast) {
430 return failure();
431 }
432 Value input = preceding_cast.arg();
433 Type input_type = input.getType();
434 Type output_type = cast_op.getType();
435
436 // Preserve quantization information for intermediate tensors.
437 auto intermediate_type = preceding_cast.getType().dyn_cast<TensorType>();
438 if (intermediate_type &&
439 intermediate_type.getElementType().isa<quant::QuantizedType>()) {
440 return failure();
441 }
442
443 // If the two types are the same, the back-to-back tfr.cast ops can be
444 // removed.
445 if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
446 rewriter.replaceOp(cast_op, {input});
447 return success();
448 }
449
450 // If the rank of the input tensor isn't ranked, we replace the pair
451 // with tf.EnsureShape op so it can be removed after shape inference or
452 // confirmed at runtime.
453 if (input_type.isa<UnrankedTensorType>() && output_type.isa<ShapedType>()) {
454 auto shape = output_type.cast<ShapedType>().getShape();
455 auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
456 rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
457 input, shape_attr);
458 }
459
460 return success();
461 }
462 };
463
464 class GetTensorShape : public OpRewritePattern<GetShapeOp> {
465 using OpRewritePattern<GetShapeOp>::OpRewritePattern;
466
467 public:
matchAndRewrite(GetShapeOp shape_op,PatternRewriter & rewriter) const468 LogicalResult matchAndRewrite(GetShapeOp shape_op,
469 PatternRewriter &rewriter) const override {
470 Operation *preceding_op = shape_op.arg().getDefiningOp();
471 if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
472 // replace this pair by shape.shape_of, so the folding works.
473 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
474 return success();
475 }
476 return failure();
477 }
478 };
479
480 class RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
481 using OpRewritePattern<GetElementOp>::OpRewritePattern;
482
483 public:
matchAndRewrite(GetElementOp ge_op,PatternRewriter & rewriter) const484 LogicalResult matchAndRewrite(GetElementOp ge_op,
485 PatternRewriter &rewriter) const override {
486 IntegerAttr index;
487 if (!matchPattern(ge_op.index(), m_Constant(&index))) {
488 return failure();
489 }
490 auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
491 ge_op.tensor_list().getDefiningOp());
492 if (!preceding_build_list ||
493 preceding_build_list.getNumOperands() <= index.getInt()) {
494 return failure();
495 }
496 Value input = preceding_build_list.getOperand(index.getInt());
497 Type output_type = ge_op.getType();
498 if (input.getType() != output_type &&
499 !output_type.isa<UnrankedTensorType>()) {
500 return failure();
501 }
502 rewriter.replaceOp(ge_op, {input});
503 return success();
504 }
505 };
506
507 class RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
508 using OpRewritePattern<GetLengthOp>::OpRewritePattern;
509
510 public:
matchAndRewrite(GetLengthOp gl_op,PatternRewriter & rewriter) const511 LogicalResult matchAndRewrite(GetLengthOp gl_op,
512 PatternRewriter &rewriter) const override {
513 auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
514 gl_op.tensor_list().getDefiningOp());
515 if (!preceding_build_list) {
516 return failure();
517 }
518 int64_t num_tensors = preceding_build_list.getNumOperands();
519 rewriter.replaceOpWithNewOp<ConstantOp>(gl_op,
520 rewriter.getIndexAttr(num_tensors));
521 return success();
522 }
523 };
524
525 class BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
526 using OpRewritePattern<BuildListOp>::OpRewritePattern;
527
528 public:
matchAndRewrite(BuildListOp bl_op,PatternRewriter & rewriter) const529 LogicalResult matchAndRewrite(BuildListOp bl_op,
530 PatternRewriter &rewriter) const override {
531 SmallVector<Attribute, 4> array_list;
532 array_list.reserve(bl_op.getNumOperands());
533 for (const auto &operand : bl_op.getOperands()) {
534 Attribute array_elt;
535 if (!matchPattern(operand, m_Constant(&array_elt))) {
536 return failure();
537 }
538 array_list.push_back(array_elt);
539 }
540 auto array_attr = rewriter.getArrayAttr(array_list);
541 rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
542 return success();
543 }
544 };
545
getQuantizedElementType(CastOp cast_op)546 quant::QuantizedType getQuantizedElementType(CastOp cast_op) {
547 if (!cast_op || !cast_op.getInputElementType()) {
548 return {};
549 }
550 return cast_op.getInputElementType()
551 .cast<TypeAttr>()
552 .getValue()
553 .dyn_cast<quant::QuantizedType>();
554 }
555
556 class RemoveRawDataOp : public OpRewritePattern<TFRQuantRawDataOp> {
557 using OpRewritePattern<TFRQuantRawDataOp>::OpRewritePattern;
558
559 public:
matchAndRewrite(TFRQuantRawDataOp raw_data_op,PatternRewriter & rewriter) const560 LogicalResult matchAndRewrite(TFRQuantRawDataOp raw_data_op,
561 PatternRewriter &rewriter) const override {
562 auto preceding_cast = dyn_cast<CastOp>(raw_data_op.input().getDefiningOp());
563 if (!getQuantizedElementType(preceding_cast)) {
564 return failure();
565 }
566 // If there are redundant casts, hoist output of raw data op originating op.
567 if (auto redundant_cast = preceding_cast.arg().getDefiningOp()) {
568 if (!isa<CastOp>(redundant_cast) ||
569 cast<CastOp>(redundant_cast).arg().getType() !=
570 preceding_cast.out().getType()) {
571 return failure();
572 }
573 raw_data_op.output().replaceAllUsesWith(
574 cast<CastOp>(redundant_cast).arg());
575 } else {
576 // If the argument of cast op is input, then simply remove the RawData op.
577 raw_data_op.output().replaceAllUsesWith(preceding_cast.out());
578 }
579 return success();
580 }
581 };
582
583 class RemoveQParamsOp : public OpRewritePattern<TFRQuantQParamsOp> {
584 using OpRewritePattern<TFRQuantQParamsOp>::OpRewritePattern;
585
586 public:
matchAndRewrite(TFRQuantQParamsOp qparams_op,PatternRewriter & rewriter) const587 LogicalResult matchAndRewrite(TFRQuantQParamsOp qparams_op,
588 PatternRewriter &rewriter) const override {
589 auto cast_op = dyn_cast<TFR::CastOp>(qparams_op.input().getDefiningOp());
590 auto cast_qtype = getQuantizedElementType(cast_op);
591 if (!cast_qtype) {
592 return failure();
593 }
594
595 TF::ConstOp scale_op;
596 TF::ConstOp zp_op;
597
598 // Reads quantization parameters from the quantized type, and converts
599 // them to constants.
600 rewriter.setInsertionPoint(qparams_op);
601 Location loc = qparams_op->getLoc();
602 if (auto qtype = cast_qtype.dyn_cast<quant::UniformQuantizedType>()) {
603 scale_op = rewriter.create<TF::ConstOp>(
604 loc, RankedTensorType::get({}, rewriter.getF32Type()),
605 rewriter.getF32FloatAttr(qtype.getScale()));
606 zp_op = rewriter.create<TF::ConstOp>(
607 loc, RankedTensorType::get({}, rewriter.getI32Type()),
608 rewriter.getI32IntegerAttr(qtype.getZeroPoint()));
609 } else if (auto qtype =
610 cast_qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
611 SmallVector<float> scales(qtype.getScales().begin(),
612 qtype.getScales().end());
613 SmallVector<int32_t> zps(qtype.getZeroPoints().begin(),
614 qtype.getZeroPoints().end());
615 const size_t num_channels = qtype.getScales().size();
616
617 auto scales_type = RankedTensorType::get(
618 {static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
619 auto scales_attr =
620 DenseElementsAttr::get(scales_type, llvm::makeArrayRef(scales));
621 scale_op = rewriter.create<TF::ConstOp>(loc, scales_attr);
622
623 auto zps_type = RankedTensorType::get(
624 {static_cast<int64_t>(num_channels)}, rewriter.getI32Type());
625 auto zps_attr = DenseElementsAttr::get(zps_type, llvm::makeArrayRef(zps));
626 zp_op = rewriter.create<TF::ConstOp>(loc, zps_attr);
627 }
628 if (!scale_op || !zp_op) {
629 return failure();
630 }
631 auto scale_cast = rewriter.create<CastOp>(loc, qparams_op.scale().getType(),
632 scale_op.output());
633 auto zp_cast =
634 rewriter.create<CastOp>(loc, qparams_op.zp().getType(), zp_op.output());
635
636 qparams_op.scale().replaceAllUsesWith(scale_cast.out());
637 qparams_op.zp().replaceAllUsesWith(zp_cast.out());
638 return success();
639 }
640 };
641
642 // TODO(b/193731721): Migrate tfr_ builtin canonicalizations to LowerTFROpPass
643 class RemoveScaleFactorOp : public OpRewritePattern<TFRQuantScaleFactorOp> {
644 using OpRewritePattern<TFRQuantScaleFactorOp>::OpRewritePattern;
645
646 public:
647 // Replace quant_scale_factor with constant tensor equivalent to
648 // TFR_ConstantTensorOp (
649 // ConstantOp (ConstAttr<F32Attr (in_scale[0] * in_scale[1] / out_scale))
650 // )
651 // Currently, all decompositions using this pattern (Conv2D, FC) have the
652 // following preconditions:
653 // * out_scale: float scalar attribute
654 // * in_scale[0] (input scale): float scalar, given by tf.Const -> tfr.cast
655 // * in_scale[1] (filter scale): float scalar/vector
656 // (per-tensor vs per-channel) quantization, given by tf.Const -> tfr.cast
matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,PatternRewriter & rewriter) const657 LogicalResult matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,
658 PatternRewriter &rewriter) const override {
659 auto out_scale_op = scale_factor_op.out_scale().getDefiningOp<ConstantOp>();
660 if (!out_scale_op) {
661 return failure();
662 }
663 const double out_scale =
664 out_scale_op.value().cast<FloatAttr>().getValueAsDouble();
665
666 auto in_scales_op =
667 scale_factor_op.in_scales().getDefiningOp<BuildListOp>();
668 if (!in_scales_op || in_scales_op.getNumOperands() != 2) {
669 // BuildListOp is variadic, but we require two values: input_scale
670 // and filter_scale.
671 return failure();
672 }
673
674 auto in_scale_op = in_scales_op.getOperand(0).getDefiningOp<CastOp>();
675 if (!in_scale_op) {
676 return failure();
677 }
678
679 DenseFPElementsAttr in_scale_attr;
680 if (!matchPattern(in_scale_op.arg(), m_Constant(&in_scale_attr)) ||
681 in_scale_attr.size() != 1) {
682 return failure();
683 }
684 const float in_scale = in_scale_attr.getValue<float>(0);
685 auto filter_scale_op = in_scales_op.getOperand(1).getDefiningOp<CastOp>();
686 if (!filter_scale_op) {
687 return failure();
688 }
689 DenseFPElementsAttr filter_scale_attr;
690 if (!matchPattern(filter_scale_op.arg(), m_Constant(&filter_scale_attr))) {
691 return failure();
692 }
693
694 // The shape of scale_type is {} (rank 0) for per-tensor quantized tensor,
695 // and {num_channels} (rank 1) for per-channel quantized one.
696 auto scale_type = filter_scale_attr.getType().dyn_cast<RankedTensorType>();
697 if (scale_type.getRank() != 0 && scale_type.getRank() != 1) {
698 return failure();
699 }
700 SmallVector<float> scale_factors;
701 scale_factors.reserve(filter_scale_attr.size());
702 for (auto value : filter_scale_attr.getFloatValues()) {
703 scale_factors.push_back(in_scale * value.convertToFloat() / out_scale);
704 }
705 rewriter.setInsertionPoint(scale_factor_op);
706 const Location loc = scale_factor_op->getLoc();
707 auto result_scale_op = rewriter.create<TF::ConstOp>(
708 loc,
709 DenseElementsAttr::get(scale_type, llvm::makeArrayRef(scale_factors)));
710 auto result_scale_cast_op = rewriter.create<CastOp>(
711 loc, scale_factor_op.getType(), result_scale_op.output());
712 scale_factor_op.scale_factor().replaceAllUsesWith(
713 result_scale_cast_op.out());
714 return success();
715 }
716 };
717
718 class RemoveRescaleOp : public OpRewritePattern<TFRQuantRescaleOp> {
719 using OpRewritePattern<TFRQuantRescaleOp>::OpRewritePattern;
720
721 public:
722 // Replace quant_rescale (input, scale, zp) with
723 // tf.Cast(tf.Round(tf.Cast(input, f32) * scale) + tf.Cast(zp, f32), i32)
matchAndRewrite(TFRQuantRescaleOp rescale_op,PatternRewriter & rewriter) const724 LogicalResult matchAndRewrite(TFRQuantRescaleOp rescale_op,
725 PatternRewriter &rewriter) const override {
726 Value input = rescale_op.input();
727 Value scale = rescale_op.scale();
728 Value zp = rescale_op.zp();
729
730 const Location loc = rescale_op->getLoc();
731 const auto result_types = rescale_op->getResultTypes();
732 auto c_false =
733 rewriter.create<ConstantOp>(loc, rewriter.getBoolAttr(false));
734 TypeAttr f32_attr = TypeAttr::get(rewriter.getF32Type());
735 TFRAttrType output_type = TFRAttrType::get(rewriter.getContext());
736 auto constant_f32_op = rewriter.create<ConstOp>(loc, output_type, f32_attr);
737 TypeAttr i32_attr = TypeAttr::get(rewriter.getI32Type());
738 auto constant_i32_op = rewriter.create<ConstOp>(loc, output_type, i32_attr);
739
740 IntegerAttr zp_attr;
741 if (!matchPattern(zp, m_Constant(&zp_attr))) {
742 return failure();
743 }
744 rewriter.setInsertionPoint(zp.getDefiningOp());
745 auto zp_tensor = rewriter.create<TF::ConstOp>(
746 loc, RankedTensorType::get({}, zp.getType()), zp_attr);
747 auto zp_cast = rewriter.create<CastOp>(
748 loc, rewriter.getType<TFRTensorType>(), zp_tensor.output());
749
750 rewriter.setInsertionPoint(rescale_op);
751 auto cast_input_to_float_op = rewriter.create<CallOp>(
752 loc, result_types, rewriter.getSymbolRefAttr("tf__cast"),
753 ArrayRef<Value>{input, constant_f32_op, c_false});
754 auto input_x_scale_op = rewriter.create<CallOp>(
755 loc, result_types, rewriter.getSymbolRefAttr("tf__mul"),
756 ArrayRef<Value>{cast_input_to_float_op.getResult(0), scale});
757 auto round_rescaled_op = rewriter.create<CallOp>(
758 loc, result_types, rewriter.getSymbolRefAttr("tf__round"),
759 ArrayRef<Value>{input_x_scale_op->getResult(0)});
760 auto cast_zp_to_float_op = rewriter.create<CallOp>(
761 loc, result_types, rewriter.getSymbolRefAttr("tf__cast"),
762 ArrayRef<Value>{zp_cast, constant_f32_op, c_false});
763 auto recentered_op = rewriter.create<CallOp>(
764 loc, result_types, rewriter.getSymbolRefAttr("tf__add"),
765 ArrayRef<Value>{round_rescaled_op->getResult(0),
766 cast_zp_to_float_op->getResult(0)});
767 auto cast_output_to_i32 = rewriter.create<CallOp>(
768 loc, result_types, rewriter.getSymbolRefAttr("tf__cast"),
769 ArrayRef<Value>{recentered_op->getResult(0), constant_i32_op, c_false});
770 rescale_op.output().replaceAllUsesWith(cast_output_to_i32.getResult(0));
771 return success();
772 }
773 };
774
775 } // namespace
776
// Registers the pattern that folds constant-fed tfr.constant_tensor ops.
void ConstantTensorOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertConstToTensorConst>(context);
}
781
// Registers the pattern that removes back-to-back tfr.cast pairs.
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<RemoveRedundantCast>(context);
}
786
// Registers the pattern that turns tfr.get_shape(tfr.cast) into
// shape.shape_of.
void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<GetTensorShape>(context);
}
791
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)792 void GetElementOp::getCanonicalizationPatterns(
793 OwningRewritePatternList &results, MLIRContext *context) {
794 results.insert<RemoveRedundantGetElement>(context);
795 }
796
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)797 void GetLengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
798 MLIRContext *context) {
799 results.insert<RemoveRedundantGetLength>(context);
800 }
801
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)802 void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
803 MLIRContext *context) {
804 results.insert<BuildConstantListAsAttr>(context);
805 }
806
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)807 void TFRQuantRawDataOp::getCanonicalizationPatterns(
808 OwningRewritePatternList &results, MLIRContext *context) {
809 results.insert<RemoveRawDataOp>(context);
810 }
811
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)812 void TFRQuantQParamsOp::getCanonicalizationPatterns(
813 OwningRewritePatternList &results, MLIRContext *context) {
814 results.insert<RemoveQParamsOp>(context);
815 }
816
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)817 void TFRQuantRescaleOp::getCanonicalizationPatterns(
818 OwningRewritePatternList &results, MLIRContext *context) {
819 results.insert<RemoveRescaleOp>(context);
820 }
821
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)822 void TFRQuantScaleFactorOp::getCanonicalizationPatterns(
823 OwningRewritePatternList &results, MLIRContext *context) {
824 results.insert<RemoveScaleFactorOp>(context);
825 }
826
fold(ArrayRef<Attribute> operands)827 OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
828 assert(operands.size() == 2 && "equal op has two operands");
829 auto ctx = getContext();
830 if (operands[0] == operands[1]) return BoolAttr::get(ctx, true);
831 return BoolAttr::get(ctx, false);
832 }
833
fold(ArrayRef<Attribute> operands)834 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
835 assert(operands.empty() && "constant has no operands");
836
837 // Return the held attribute value.
838 return value();
839 }
840
841 // CallableOpInterface
getCallableRegion()842 Region *TFRFuncOp::getCallableRegion() {
843 return isExternal() ? nullptr : &body().front();
844 }
845
846 // CallableOpInterface
getCallableResults()847 ArrayRef<Type> TFRFuncOp::getCallableResults() {
848 return getType().getResults();
849 }
850
851 //===----------------------------------------------------------------------===//
852 // Dialect type definitions
853 //===----------------------------------------------------------------------===//
854
855 // Parses a TFR type.
856 // tfr_type ::= tensor_type | tensor_list_type | attr_type
857 // string_list ::= `[` string-literal (, string-literal)+ `]`
858 // tensor_type ::= `tensor`
859 // | `tensor<` (string-literal | string_list) '>'
860 // tensor_list_type ::= `tensor_list`
861 // | `tensor_list<` (string-literal | string_list) '>'
862 // attr_type ::= `attr`
parseType(DialectAsmParser & parser) const863 Type TFRDialect::parseType(DialectAsmParser &parser) const {
864 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
865 MLIRContext *ctx = loc.getContext();
866
867 StringRef typeNameSpelling;
868 if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
869 llvm::SmallVector<StringAttr, 4> attrs;
870 if (succeeded(parser.parseOptionalLess())) {
871 bool l_square_parsed = false;
872 if (succeeded(parser.parseOptionalLSquare())) {
873 l_square_parsed = true;
874 }
875
876 do {
877 StringRef attr;
878 if (failed(parser.parseKeyword(&attr))) return {};
879 attrs.push_back(StringAttr::get(ctx, attr));
880 } while (succeeded(parser.parseOptionalComma()));
881
882 if (l_square_parsed && failed(parser.parseRSquare())) {
883 parser.emitError(parser.getNameLoc(), "expected ']'");
884 }
885
886 if (failed(parser.parseGreater())) {
887 parser.emitError(parser.getNameLoc(), "expected '>'");
888 }
889 }
890
891 if (typeNameSpelling == "tensor") {
892 return TFRTensorType::getChecked(attrs, loc);
893 } else if (typeNameSpelling == "tensor_list") {
894 return TFRTensorListType::getChecked(attrs, loc);
895 } else if (typeNameSpelling == "attr") {
896 return TFRAttrType::getChecked(loc, loc.getContext());
897 } else {
898 parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
899 return {};
900 }
901 }
902
printType(Type type,DialectAsmPrinter & os) const903 void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
904 llvm::ArrayRef<StringAttr> attrs;
905
906 if (type.isa<TFRAttrType>()) {
907 os << "attr";
908 return;
909 }
910 if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
911 attrs = tensor_ty.getAttrKeys();
912 os << "tensor";
913 } else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
914 attrs = tensor_list_ty.getAttrKeys();
915 os << "tensor_list";
916 } else {
917 llvm_unreachable("Unhandled tfr type");
918 }
919
920 if (attrs.empty()) return;
921 os << "<";
922
923 if (attrs.size() > 1) {
924 os << "[";
925 }
926
927 llvm::interleaveComma(attrs, os,
928 [&](StringAttr attr) { os << attr.getValue(); });
929
930 if (attrs.size() > 1) {
931 os << "]";
932 }
933 os << ">";
934 }
935
936 } // namespace TFR
937 } // namespace mlir
938