• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <string>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/Twine.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
32 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
34 #include "mlir/IR/Attributes.h"  // from @llvm-project
35 #include "mlir/IR/Builders.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
37 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
39 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
40 #include "mlir/IR/FunctionImplementation.h"  // from @llvm-project
41 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
42 #include "mlir/IR/Matchers.h"  // from @llvm-project
43 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
44 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
45 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
46 #include "mlir/IR/Types.h"  // from @llvm-project
47 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
48 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
49 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
50 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
52 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
53 
54 namespace mlir {
55 
56 namespace TFR {
57 
58 //===----------------------------------------------------------------------===//
59 // InlinerInterface
60 //===----------------------------------------------------------------------===//
61 
62 namespace {
63 /// This class defines the interface for inlining within the TFR dialect.
64 class TFRInlinerInterface : public DialectInlinerInterface {
65   using DialectInlinerInterface::DialectInlinerInterface;
66 
67  public:
68   // Allow all call operations to be inlined.
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const69   bool isLegalToInline(Operation *call, Operation *callable,
70                        bool wouldBeCloned) const final {
71     return true;
72   }
73   // Returns true if the given region 'src' can be inlined into the region
74   // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping &) const75   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
76                        BlockAndValueMapping &) const final {
77     return true;
78   }
79 
80   // Returns true if the given operation 'op', that is registered to this
81   // dialect, can be inlined into the region 'dest' that is attached to an
82   // operation registered to the current dialect.
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping &) const83   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
84                        BlockAndValueMapping &) const final {
85     return true;
86   }
87 
88   // Handle the given inlined terminator by replacing it with a new operation
89   // as necessary. Required when the region has only one block.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const90   void handleTerminator(Operation *op,
91                         ArrayRef<Value> valuesToRepl) const final {
92     auto retValOp = dyn_cast<TFRReturnOp>(op);
93     if (!retValOp) return;
94 
95     for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
96       std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
97     }
98   }
99 
100   // Attempts to materialize a conversion for a type mismatch between a call
101   // from this dialect, and a callable region. This method should generate an
102   // operation that takes 'input' as the only operand, and produces a single
103   // result of 'resultType'. If a conversion can not be generated, nullptr
104   // should be returned.
materializeCallConversion(OpBuilder & builder,Value input,Type result_type,Location conversion_loc) const105   Operation *materializeCallConversion(OpBuilder &builder, Value input,
106                                        Type result_type,
107                                        Location conversion_loc) const final {
108     if (!input.getType().isa<IntegerType>() ||
109         !result_type.isa<IntegerType>()) {
110       return nullptr;
111     }
112     auto input_itype = input.getType().cast<IntegerType>();
113     auto result_itype = result_type.cast<IntegerType>();
114     if (input_itype.getWidth() == result_itype.getWidth()) return nullptr;
115     if (input_itype.getWidth() > result_itype.getWidth()) {
116       return builder.create<TruncateIOp>(conversion_loc, result_type, input);
117     } else {
118       return builder.create<SignExtendIOp>(conversion_loc, result_type, input);
119     }
120   }
121 };
122 }  // namespace
123 
124 //===----------------------------------------------------------------------===//
125 // TFR Dialect
126 //===----------------------------------------------------------------------===//
127 
TFRDialect(MLIRContext * context)128 TFRDialect::TFRDialect(MLIRContext *context)
129     : Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
130   // TFR depends on TensorFlow for its canonicalization
131   context->getOrLoadDialect<TF::TensorFlowDialect>();
132 
133   addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
134   addOperations<
135 #define GET_OP_LIST
136 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
137       >();
138 
139   addInterfaces<TFRInlinerInterface>();
140 }
141 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)142 Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
143                                            Type type, Location loc) {
144   if (ConstantOp::isBuildableWith(value, type))
145     return builder.create<ConstantOp>(loc, type, value);
146   return nullptr;
147 }
148 
classof(Type type)149 bool TFRType::classof(Type type) {
150   return llvm::isa<TFRDialect>(type.getDialect());
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // Custom op methods
155 //===----------------------------------------------------------------------===//
156 
Verify(ConstantTensorOp op)157 static LogicalResult Verify(ConstantTensorOp op) {
158   auto input_type = op.arg().getType();
159   auto output_type = op.out().getType();
160 
161   if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
162     return success();
163   }
164 
165   auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
166   if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
167     op.emitError("output type should be static and ranked.");
168     return failure();
169   }
170 
171   if (output_tensor_type.getRank() == 0) {
172     bool same_scalar = output_tensor_type.getElementType() == input_type;
173     if (!same_scalar) {
174       op.emitError("input and output should have the same scalar types.");
175     }
176     return success(same_scalar);
177   }
178 
179   if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
180     bool same_element_type = output_tensor_type.getElementType() ==
181                              input_vector_type.getElementType();
182     bool same_shape =
183         output_tensor_type.getShape() == input_vector_type.getShape();
184     if (!same_element_type || !same_shape) {
185       op.emitError("input and output should have same shape and element type.");
186     }
187     return success(same_element_type && same_shape);
188   }
189 
190   op.emitError("input can not be converted to an output tensor.");
191   return failure();
192 }
193 
Verify(TFRFuncOp func)194 static LogicalResult Verify(TFRFuncOp func) {
195   // Collect all attribute names used by the tensor and tensor list arguments
196   // and returns. Also, collect the names of all the attribute arguments as the
197   // defined list. Later on, the used attribute names will be verified to be in
198   // the defined list.
199   llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;
200 
201   // While scanning the arguments, record the start/end indices of each argument
202   // type, so the order can be verified as well.
203   // TODO(fengliuai): the attribute arguments with default values need to be
204   // at the end?
205   int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
206       last_tensor_list = -1, first_attr = -1;
207   for (auto arg : llvm::enumerate(func.getType().getInputs())) {
208     Type arg_type = arg.value();
209 
210     if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
211       if (first_tensor == -1) {
212         first_tensor = arg.index();
213       }
214       last_tensor = arg.index();
215       auto used = tensor.getAttrKeys();
216       input_used_attrs.append(used.begin(), used.end());
217       continue;
218     }
219 
220     if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
221       if (first_tensor_list == -1) {
222         first_tensor_list = arg.index();
223       }
224       last_tensor_list = arg.index();
225       auto used = tensor_list.getAttrKeys();
226       input_used_attrs.append(used.begin(), used.end());
227       continue;
228     }
229 
230     if (!arg_type.isa<TensorType>()) {
231       if (first_attr == -1) {
232         first_attr = arg.index();
233       }
234       auto name =
235           func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
236       if (!name) {
237         func.emitError(
238             llvm::Twine(arg.index()) +
239             " attribute argument doesn't have a tfr.name attribute.");
240         return failure();
241       }
242       continue;
243     }
244 
245     func.emitError("Builtin TensorType isn't allowed as the argument.");
246     return failure();
247   }
248 
249   // Collect all the undefined attributes used in the inputs.
250   llvm::SmallVector<StringAttr, 4> undefined_attrs;
251   for (auto attr : input_used_attrs) {
252     if (!func->getAttr(attr.getValue())) {
253       undefined_attrs.push_back(attr);
254     }
255   }
256 
257   // Verify the argument order: tensors, tensor list, attributes; and also
258   // verify there is at most one tensor list argument.
259   if (first_attr != -1 &&
260       (first_attr < last_tensor_list || first_attr < last_tensor)) {
261     func.emitError(
262         "tfr.tensor/tfr.tensor_list argument should be before non tensor "
263         "arguments.");
264     return failure();
265   }
266   // The order between tensor arguments and tensor list arguments and the number
267   // of tensor list arguments are verified only when they couldn't be determined
268   // by the attributes.
269   if (!undefined_attrs.empty()) {
270     if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
271       func.emitError(
272           "tfr.tensor argument should be before tfr.tensor_list argument.");
273       return failure();
274     }
275     if (first_tensor_list != last_tensor_list) {
276       func.emitError("More than one tfr.tensor_list argument isn't allowed.");
277       return failure();
278     }
279   }
280 
281   // Verify the result order: tensor, tensor list, and also verify at most one
282   // tensor list result.
283   int undefined_input_attrs_number = undefined_attrs.size();
284   bool seen_tensor_list = false, has_tensor_list_order_error = false,
285        has_multiple_tensor_lists_error = false;
286   for (auto result_type : func.getType().getResults()) {
287     if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
288       if (seen_tensor_list) {
289         has_tensor_list_order_error = true;
290       } else {
291         auto used = tensor.getAttrKeys();
292         output_used_attrs.append(used.begin(), used.end());
293       }
294       continue;
295     }
296 
297     if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
298       if (seen_tensor_list) {
299         has_multiple_tensor_lists_error = true;
300       } else {
301         seen_tensor_list = true;
302         auto used = tensor_list.getAttrKeys();
303         output_used_attrs.append(used.begin(), used.end());
304       }
305       continue;
306     }
307 
308     func.emitError(
309         "None tfr.tensor/tfr.tensor_list results aren't allowed as a "
310         "result.");
311     return failure();
312   }
313 
314   // Collect all the undefined attributes used in the outputs.
315   for (auto attr : output_used_attrs) {
316     if (!func->getAttr(attr.getValue())) {
317       undefined_attrs.push_back(attr);
318     }
319   }
320 
321   // Verify there are no tensor/tensor list order error and multiple tensor
322   // list arguments error.
323   if (undefined_input_attrs_number != undefined_attrs.size()) {
324     if (has_tensor_list_order_error) {
325       func.emitError(
326           "tfr.tensor result should be before tfr.tensor_list result.");
327       return failure();
328     } else if (has_multiple_tensor_lists_error) {
329       func.emitError("More than one tfr.tensor_list result isn't allowed.");
330       return failure();
331     }
332   }
333 
334   // TODO(fengliuai): We might want to refine this constraint because the
335   // tensor element type can be derived.
336   if (!undefined_attrs.empty()) {
337     llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
338     std::transform(undefined_attrs.begin(), undefined_attrs.end(),
339                    attr_names.begin(),
340                    [](StringAttr attr) { return attr.getValue().str(); });
341     func.emitError(llvm::Twine("Undefined attributes are used: ",
342                                llvm::join(attr_names, ",")));
343     return failure();
344   }
345 
346   return success();
347 }
348 
ParseFuncOp(OpAsmParser & parser,OperationState * result)349 static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) {
350   auto build_func_type = [](Builder &builder, ArrayRef<Type> arg_types,
351                             ArrayRef<Type> results,
352                             function_like_impl::VariadicFlag, std::string &) {
353     return builder.getFunctionType(arg_types, results);
354   };
355   return function_like_impl::parseFunctionLikeOp(
356       parser, *result, /*allowVariadic=*/false, build_func_type);
357 }
358 
PrintFuncOp(OpAsmPrinter & p,TFRFuncOp op)359 static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) {
360   FunctionType fn_type = op.getType();
361   function_like_impl::printFunctionLikeOp(
362       p, op, fn_type.getInputs(), /*isVariadic=*/false, fn_type.getResults());
363 }
364 
365 }  // namespace TFR
366 }  // namespace mlir
367 
368 //===----------------------------------------------------------------------===//
369 // TableGen'd op method definitions
370 //===----------------------------------------------------------------------===//
371 
372 #define GET_OP_CLASSES
373 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
374 
375 namespace mlir {
376 namespace TFR {
377 namespace {
378 class ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
379   using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
380 
381  public:
matchAndRewrite(ConstantTensorOp cst_tensor_op,PatternRewriter & rewriter) const382   LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
383                                 PatternRewriter &rewriter) const override {
384     Location loc = cst_tensor_op.getLoc();
385     Type out_type = cst_tensor_op.getType();
386     Operation *new_cst = nullptr;
387 
388     ArrayAttr array;
389     if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
390       llvm::DenseSet<Type> all_types;
391       for (auto it : array) {
392         all_types.insert(it.getType());
393       }
394       if (all_types.size() != 1) return failure();
395       ShapedType new_out_type = RankedTensorType::get(
396           {static_cast<int64_t>(array.size())}, *all_types.begin());
397       DenseElementsAttr attr =
398           DenseElementsAttr::get(new_out_type, array.getValue());
399       new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
400       if (out_type.isa<TFRTensorType>()) {
401         new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
402       }
403       rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
404       return success();
405     }
406 
407     Attribute scalar;
408     if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
409       Type new_out_type = RankedTensorType::get({}, scalar.getType());
410       new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
411       if (out_type.isa<TFRTensorType>()) {
412         new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
413       }
414       rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
415       return success();
416     }
417     return failure();
418   }
419 };
420 
421 class RemoveRedundantCast : public OpRewritePattern<CastOp> {
422   using OpRewritePattern<CastOp>::OpRewritePattern;
423 
424  public:
matchAndRewrite(CastOp cast_op,PatternRewriter & rewriter) const425   LogicalResult matchAndRewrite(CastOp cast_op,
426                                 PatternRewriter &rewriter) const override {
427     auto preceding_cast =
428         llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
429     if (!preceding_cast) {
430       return failure();
431     }
432     Value input = preceding_cast.arg();
433     Type input_type = input.getType();
434     Type output_type = cast_op.getType();
435 
436     // Preserve quantization information for intermediate tensors.
437     auto intermediate_type = preceding_cast.getType().dyn_cast<TensorType>();
438     if (intermediate_type &&
439         intermediate_type.getElementType().isa<quant::QuantizedType>()) {
440       return failure();
441     }
442 
443     // If the two types are the same, the back-to-back tfr.cast ops can be
444     // removed.
445     if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
446       rewriter.replaceOp(cast_op, {input});
447       return success();
448     }
449 
450     // If the rank of the input tensor isn't ranked, we replace the pair
451     // with tf.EnsureShape op so it can be removed after shape inference or
452     // confirmed at runtime.
453     if (input_type.isa<UnrankedTensorType>() && output_type.isa<ShapedType>()) {
454       auto shape = output_type.cast<ShapedType>().getShape();
455       auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
456       rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
457                                                      input, shape_attr);
458     }
459 
460     return success();
461   }
462 };
463 
464 class GetTensorShape : public OpRewritePattern<GetShapeOp> {
465   using OpRewritePattern<GetShapeOp>::OpRewritePattern;
466 
467  public:
matchAndRewrite(GetShapeOp shape_op,PatternRewriter & rewriter) const468   LogicalResult matchAndRewrite(GetShapeOp shape_op,
469                                 PatternRewriter &rewriter) const override {
470     Operation *preceding_op = shape_op.arg().getDefiningOp();
471     if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
472       // replace this pair by shape.shape_of, so the folding works.
473       rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
474       return success();
475     }
476     return failure();
477   }
478 };
479 
480 class RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
481   using OpRewritePattern<GetElementOp>::OpRewritePattern;
482 
483  public:
matchAndRewrite(GetElementOp ge_op,PatternRewriter & rewriter) const484   LogicalResult matchAndRewrite(GetElementOp ge_op,
485                                 PatternRewriter &rewriter) const override {
486     IntegerAttr index;
487     if (!matchPattern(ge_op.index(), m_Constant(&index))) {
488       return failure();
489     }
490     auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
491         ge_op.tensor_list().getDefiningOp());
492     if (!preceding_build_list ||
493         preceding_build_list.getNumOperands() <= index.getInt()) {
494       return failure();
495     }
496     Value input = preceding_build_list.getOperand(index.getInt());
497     Type output_type = ge_op.getType();
498     if (input.getType() != output_type &&
499         !output_type.isa<UnrankedTensorType>()) {
500       return failure();
501     }
502     rewriter.replaceOp(ge_op, {input});
503     return success();
504   }
505 };
506 
507 class RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
508   using OpRewritePattern<GetLengthOp>::OpRewritePattern;
509 
510  public:
matchAndRewrite(GetLengthOp gl_op,PatternRewriter & rewriter) const511   LogicalResult matchAndRewrite(GetLengthOp gl_op,
512                                 PatternRewriter &rewriter) const override {
513     auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
514         gl_op.tensor_list().getDefiningOp());
515     if (!preceding_build_list) {
516       return failure();
517     }
518     int64_t num_tensors = preceding_build_list.getNumOperands();
519     rewriter.replaceOpWithNewOp<ConstantOp>(gl_op,
520                                             rewriter.getIndexAttr(num_tensors));
521     return success();
522   }
523 };
524 
525 class BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
526   using OpRewritePattern<BuildListOp>::OpRewritePattern;
527 
528  public:
matchAndRewrite(BuildListOp bl_op,PatternRewriter & rewriter) const529   LogicalResult matchAndRewrite(BuildListOp bl_op,
530                                 PatternRewriter &rewriter) const override {
531     SmallVector<Attribute, 4> array_list;
532     array_list.reserve(bl_op.getNumOperands());
533     for (const auto &operand : bl_op.getOperands()) {
534       Attribute array_elt;
535       if (!matchPattern(operand, m_Constant(&array_elt))) {
536         return failure();
537       }
538       array_list.push_back(array_elt);
539     }
540     auto array_attr = rewriter.getArrayAttr(array_list);
541     rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
542     return success();
543   }
544 };
545 
getQuantizedElementType(CastOp cast_op)546 quant::QuantizedType getQuantizedElementType(CastOp cast_op) {
547   if (!cast_op || !cast_op.getInputElementType()) {
548     return {};
549   }
550   return cast_op.getInputElementType()
551       .cast<TypeAttr>()
552       .getValue()
553       .dyn_cast<quant::QuantizedType>();
554 }
555 
556 class RemoveRawDataOp : public OpRewritePattern<TFRQuantRawDataOp> {
557   using OpRewritePattern<TFRQuantRawDataOp>::OpRewritePattern;
558 
559  public:
matchAndRewrite(TFRQuantRawDataOp raw_data_op,PatternRewriter & rewriter) const560   LogicalResult matchAndRewrite(TFRQuantRawDataOp raw_data_op,
561                                 PatternRewriter &rewriter) const override {
562     auto preceding_cast = dyn_cast<CastOp>(raw_data_op.input().getDefiningOp());
563     if (!getQuantizedElementType(preceding_cast)) {
564       return failure();
565     }
566     // If there are redundant casts, hoist output of raw data op originating op.
567     if (auto redundant_cast = preceding_cast.arg().getDefiningOp()) {
568       if (!isa<CastOp>(redundant_cast) ||
569           cast<CastOp>(redundant_cast).arg().getType() !=
570               preceding_cast.out().getType()) {
571         return failure();
572       }
573       raw_data_op.output().replaceAllUsesWith(
574           cast<CastOp>(redundant_cast).arg());
575     } else {
576       // If the argument of cast op is input, then simply remove the RawData op.
577       raw_data_op.output().replaceAllUsesWith(preceding_cast.out());
578     }
579     return success();
580   }
581 };
582 
583 class RemoveQParamsOp : public OpRewritePattern<TFRQuantQParamsOp> {
584   using OpRewritePattern<TFRQuantQParamsOp>::OpRewritePattern;
585 
586  public:
matchAndRewrite(TFRQuantQParamsOp qparams_op,PatternRewriter & rewriter) const587   LogicalResult matchAndRewrite(TFRQuantQParamsOp qparams_op,
588                                 PatternRewriter &rewriter) const override {
589     auto cast_op = dyn_cast<TFR::CastOp>(qparams_op.input().getDefiningOp());
590     auto cast_qtype = getQuantizedElementType(cast_op);
591     if (!cast_qtype) {
592       return failure();
593     }
594 
595     TF::ConstOp scale_op;
596     TF::ConstOp zp_op;
597 
598     // Reads quantization parameters from the quantized type, and converts
599     // them to constants.
600     rewriter.setInsertionPoint(qparams_op);
601     Location loc = qparams_op->getLoc();
602     if (auto qtype = cast_qtype.dyn_cast<quant::UniformQuantizedType>()) {
603       scale_op = rewriter.create<TF::ConstOp>(
604           loc, RankedTensorType::get({}, rewriter.getF32Type()),
605           rewriter.getF32FloatAttr(qtype.getScale()));
606       zp_op = rewriter.create<TF::ConstOp>(
607           loc, RankedTensorType::get({}, rewriter.getI32Type()),
608           rewriter.getI32IntegerAttr(qtype.getZeroPoint()));
609     } else if (auto qtype =
610                    cast_qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
611       SmallVector<float> scales(qtype.getScales().begin(),
612                                 qtype.getScales().end());
613       SmallVector<int32_t> zps(qtype.getZeroPoints().begin(),
614                                qtype.getZeroPoints().end());
615       const size_t num_channels = qtype.getScales().size();
616 
617       auto scales_type = RankedTensorType::get(
618           {static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
619       auto scales_attr =
620           DenseElementsAttr::get(scales_type, llvm::makeArrayRef(scales));
621       scale_op = rewriter.create<TF::ConstOp>(loc, scales_attr);
622 
623       auto zps_type = RankedTensorType::get(
624           {static_cast<int64_t>(num_channels)}, rewriter.getI32Type());
625       auto zps_attr = DenseElementsAttr::get(zps_type, llvm::makeArrayRef(zps));
626       zp_op = rewriter.create<TF::ConstOp>(loc, zps_attr);
627     }
628     if (!scale_op || !zp_op) {
629       return failure();
630     }
631     auto scale_cast = rewriter.create<CastOp>(loc, qparams_op.scale().getType(),
632                                               scale_op.output());
633     auto zp_cast =
634         rewriter.create<CastOp>(loc, qparams_op.zp().getType(), zp_op.output());
635 
636     qparams_op.scale().replaceAllUsesWith(scale_cast.out());
637     qparams_op.zp().replaceAllUsesWith(zp_cast.out());
638     return success();
639   }
640 };
641 
642 // TODO(b/193731721): Migrate tfr_ builtin canonicalizations to LowerTFROpPass
643 class RemoveScaleFactorOp : public OpRewritePattern<TFRQuantScaleFactorOp> {
644   using OpRewritePattern<TFRQuantScaleFactorOp>::OpRewritePattern;
645 
646  public:
647   // Replace quant_scale_factor with constant tensor equivalent to
648   // TFR_ConstantTensorOp (
649   //   ConstantOp (ConstAttr<F32Attr (in_scale[0] * in_scale[1] / out_scale))
650   // )
651   // Currently, all decompositions using this pattern (Conv2D, FC) have the
652   // following preconditions:
653   // * out_scale: float scalar attribute
654   // * in_scale[0] (input scale): float scalar, given by tf.Const -> tfr.cast
655   // * in_scale[1] (filter scale): float scalar/vector
656   //     (per-tensor vs per-channel) quantization, given by tf.Const -> tfr.cast
matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,PatternRewriter & rewriter) const657   LogicalResult matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,
658                                 PatternRewriter &rewriter) const override {
659     auto out_scale_op = scale_factor_op.out_scale().getDefiningOp<ConstantOp>();
660     if (!out_scale_op) {
661       return failure();
662     }
663     const double out_scale =
664         out_scale_op.value().cast<FloatAttr>().getValueAsDouble();
665 
666     auto in_scales_op =
667         scale_factor_op.in_scales().getDefiningOp<BuildListOp>();
668     if (!in_scales_op || in_scales_op.getNumOperands() != 2) {
669       // BuildListOp is variadic, but we require two values: input_scale
670       // and filter_scale.
671       return failure();
672     }
673 
674     auto in_scale_op = in_scales_op.getOperand(0).getDefiningOp<CastOp>();
675     if (!in_scale_op) {
676       return failure();
677     }
678 
679     DenseFPElementsAttr in_scale_attr;
680     if (!matchPattern(in_scale_op.arg(), m_Constant(&in_scale_attr)) ||
681         in_scale_attr.size() != 1) {
682       return failure();
683     }
684     const float in_scale = in_scale_attr.getValue<float>(0);
685     auto filter_scale_op = in_scales_op.getOperand(1).getDefiningOp<CastOp>();
686     if (!filter_scale_op) {
687       return failure();
688     }
689     DenseFPElementsAttr filter_scale_attr;
690     if (!matchPattern(filter_scale_op.arg(), m_Constant(&filter_scale_attr))) {
691       return failure();
692     }
693 
694     // The shape of scale_type is {} (rank 0) for per-tensor quantized tensor,
695     // and {num_channels} (rank 1) for per-channel quantized one.
696     auto scale_type = filter_scale_attr.getType().dyn_cast<RankedTensorType>();
697     if (scale_type.getRank() != 0 && scale_type.getRank() != 1) {
698       return failure();
699     }
700     SmallVector<float> scale_factors;
701     scale_factors.reserve(filter_scale_attr.size());
702     for (auto value : filter_scale_attr.getFloatValues()) {
703       scale_factors.push_back(in_scale * value.convertToFloat() / out_scale);
704     }
705     rewriter.setInsertionPoint(scale_factor_op);
706     const Location loc = scale_factor_op->getLoc();
707     auto result_scale_op = rewriter.create<TF::ConstOp>(
708         loc,
709         DenseElementsAttr::get(scale_type, llvm::makeArrayRef(scale_factors)));
710     auto result_scale_cast_op = rewriter.create<CastOp>(
711         loc, scale_factor_op.getType(), result_scale_op.output());
712     scale_factor_op.scale_factor().replaceAllUsesWith(
713         result_scale_cast_op.out());
714     return success();
715   }
716 };
717 
718 class RemoveRescaleOp : public OpRewritePattern<TFRQuantRescaleOp> {
719   using OpRewritePattern<TFRQuantRescaleOp>::OpRewritePattern;
720 
721  public:
722   // Replace quant_rescale (input, scale, zp) with
723   // tf.Cast(tf.Round(tf.Cast(input, f32) * scale) + tf.Cast(zp, f32), i32)
matchAndRewrite(TFRQuantRescaleOp rescale_op,PatternRewriter & rewriter) const724   LogicalResult matchAndRewrite(TFRQuantRescaleOp rescale_op,
725                                 PatternRewriter &rewriter) const override {
726     Value input = rescale_op.input();
727     Value scale = rescale_op.scale();
728     Value zp = rescale_op.zp();
729 
730     const Location loc = rescale_op->getLoc();
731     const auto result_types = rescale_op->getResultTypes();
732     auto c_false =
733         rewriter.create<ConstantOp>(loc, rewriter.getBoolAttr(false));
734     TypeAttr f32_attr = TypeAttr::get(rewriter.getF32Type());
735     TFRAttrType output_type = TFRAttrType::get(rewriter.getContext());
736     auto constant_f32_op = rewriter.create<ConstOp>(loc, output_type, f32_attr);
737     TypeAttr i32_attr = TypeAttr::get(rewriter.getI32Type());
738     auto constant_i32_op = rewriter.create<ConstOp>(loc, output_type, i32_attr);
739 
740     IntegerAttr zp_attr;
741     if (!matchPattern(zp, m_Constant(&zp_attr))) {
742       return failure();
743     }
744     rewriter.setInsertionPoint(zp.getDefiningOp());
745     auto zp_tensor = rewriter.create<TF::ConstOp>(
746         loc, RankedTensorType::get({}, zp.getType()), zp_attr);
747     auto zp_cast = rewriter.create<CastOp>(
748         loc, rewriter.getType<TFRTensorType>(), zp_tensor.output());
749 
750     rewriter.setInsertionPoint(rescale_op);
751     auto cast_input_to_float_op = rewriter.create<CallOp>(
752         loc, result_types, rewriter.getSymbolRefAttr("tf__cast"),
753         ArrayRef<Value>{input, constant_f32_op, c_false});
754     auto input_x_scale_op = rewriter.create<CallOp>(
755         loc, result_types, rewriter.getSymbolRefAttr("tf__mul"),
756         ArrayRef<Value>{cast_input_to_float_op.getResult(0), scale});
757     auto round_rescaled_op = rewriter.create<CallOp>(
758         loc, result_types, rewriter.getSymbolRefAttr("tf__round"),
759         ArrayRef<Value>{input_x_scale_op->getResult(0)});
760     auto cast_zp_to_float_op = rewriter.create<CallOp>(
761         loc, result_types, rewriter.getSymbolRefAttr("tf__cast"),
762         ArrayRef<Value>{zp_cast, constant_f32_op, c_false});
763     auto recentered_op = rewriter.create<CallOp>(
764         loc, result_types, rewriter.getSymbolRefAttr("tf__add"),
765         ArrayRef<Value>{round_rescaled_op->getResult(0),
766                         cast_zp_to_float_op->getResult(0)});
767     auto cast_output_to_i32 = rewriter.create<CallOp>(
768         loc, result_types, rewriter.getSymbolRefAttr("tf__cast"),
769         ArrayRef<Value>{recentered_op->getResult(0), constant_i32_op, c_false});
770     rescale_op.output().replaceAllUsesWith(cast_output_to_i32.getResult(0));
771     return success();
772   }
773 };
774 
775 }  // namespace
776 
// Registers the ConvertConstToTensorConst canonicalization pattern for
// tfr.constant_tensor.
void ConstantTensorOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertConstToTensorConst>(context);
}
781 
// Registers the RemoveRedundantCast canonicalization pattern for tfr.cast.
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<RemoveRedundantCast>(context);
}
786 
// Registers the GetTensorShape canonicalization pattern for tfr.get_shape.
void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<GetTensorShape>(context);
}
791 
// Registers the RemoveRedundantGetElement canonicalization pattern for
// tfr.get_element.
void GetElementOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveRedundantGetElement>(context);
}
796 
// Registers the RemoveRedundantGetLength canonicalization pattern for
// tfr.get_length.
void GetLengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<RemoveRedundantGetLength>(context);
}
801 
// Registers the BuildConstantListAsAttr canonicalization pattern for
// tfr.build_list.
void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<BuildConstantListAsAttr>(context);
}
806 
// Registers the RemoveRawDataOp canonicalization pattern for
// tfr.quant_raw_data.
void TFRQuantRawDataOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveRawDataOp>(context);
}
811 
// Registers the RemoveQParamsOp canonicalization pattern for
// tfr.quant_qparams.
void TFRQuantQParamsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveQParamsOp>(context);
}
816 
// Registers the RemoveRescaleOp canonicalization pattern for
// tfr.quant_rescale (see RemoveRescaleOp above).
void TFRQuantRescaleOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveRescaleOp>(context);
}
821 
// Registers the RemoveScaleFactorOp canonicalization pattern for
// tfr.quant_scale_factor.
void TFRQuantScaleFactorOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveScaleFactorOp>(context);
}
826 
fold(ArrayRef<Attribute> operands)827 OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
828   assert(operands.size() == 2 && "equal op has two operands");
829   auto ctx = getContext();
830   if (operands[0] == operands[1]) return BoolAttr::get(ctx, true);
831   return BoolAttr::get(ctx, false);
832 }
833 
// Folds tfr.constant to its attribute value; constants always fold.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}
840 
841 // CallableOpInterface
getCallableRegion()842 Region *TFRFuncOp::getCallableRegion() {
843   return isExternal() ? nullptr : &body().front();
844 }
845 
846 // CallableOpInterface
// CallableOpInterface: result types are taken from the function signature.
ArrayRef<Type> TFRFuncOp::getCallableResults() {
  return getType().getResults();
}
850 
851 //===----------------------------------------------------------------------===//
852 // Dialect type definitions
853 //===----------------------------------------------------------------------===//
854 
855 // Parses a TFR type.
856 //   tfr_type ::= tensor_type | tensor_list_type | attr_type
857 //   string_list ::= `[` string-literal (, string-literal)+ `]`
858 //   tensor_type ::= `tensor`
859 //                 | `tensor<` (string-literal | string_list)  '>'
860 //   tensor_list_type ::= `tensor_list`
861 //                      | `tensor_list<` (string-literal | string_list)  '>'
862 //   attr_type ::= `attr`
parseType(DialectAsmParser & parser) const863 Type TFRDialect::parseType(DialectAsmParser &parser) const {
864   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
865   MLIRContext *ctx = loc.getContext();
866 
867   StringRef typeNameSpelling;
868   if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
869   llvm::SmallVector<StringAttr, 4> attrs;
870   if (succeeded(parser.parseOptionalLess())) {
871     bool l_square_parsed = false;
872     if (succeeded(parser.parseOptionalLSquare())) {
873       l_square_parsed = true;
874     }
875 
876     do {
877       StringRef attr;
878       if (failed(parser.parseKeyword(&attr))) return {};
879       attrs.push_back(StringAttr::get(ctx, attr));
880     } while (succeeded(parser.parseOptionalComma()));
881 
882     if (l_square_parsed && failed(parser.parseRSquare())) {
883       parser.emitError(parser.getNameLoc(), "expected ']'");
884     }
885 
886     if (failed(parser.parseGreater())) {
887       parser.emitError(parser.getNameLoc(), "expected '>'");
888     }
889   }
890 
891   if (typeNameSpelling == "tensor") {
892     return TFRTensorType::getChecked(attrs, loc);
893   } else if (typeNameSpelling == "tensor_list") {
894     return TFRTensorListType::getChecked(attrs, loc);
895   } else if (typeNameSpelling == "attr") {
896     return TFRAttrType::getChecked(loc, loc.getContext());
897   } else {
898     parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
899     return {};
900   }
901 }
902 
printType(Type type,DialectAsmPrinter & os) const903 void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
904   llvm::ArrayRef<StringAttr> attrs;
905 
906   if (type.isa<TFRAttrType>()) {
907     os << "attr";
908     return;
909   }
910   if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
911     attrs = tensor_ty.getAttrKeys();
912     os << "tensor";
913   } else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
914     attrs = tensor_list_ty.getAttrKeys();
915     os << "tensor_list";
916   } else {
917     llvm_unreachable("Unhandled tfr type");
918   }
919 
920   if (attrs.empty()) return;
921   os << "<";
922 
923   if (attrs.size() > 1) {
924     os << "[";
925   }
926 
927   llvm::interleaveComma(attrs, os,
928                         [&](StringAttr attr) { os << attr.getValue(); });
929 
930   if (attrs.size() > 1) {
931     os << "]";
932   }
933   os << ">";
934 }
935 
936 }  // namespace TFR
937 }  // namespace mlir
938