• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The TF dialect uses some TF types that are illegal in the MHLO dialect and
17 // some generic types that are legal in MHLO. This pass legalizes TF types into
18 // types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
19 // Rewrites here should run before TF to MHLO op legalizations are run.
20 // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
21 // rather than its own pass.
22 
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Support/LLVM.h"  // from @llvm-project
30 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
31 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
33 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
34 
35 #define DEBUG_TYPE "xla-legalize-tf-types"
36 
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40 
IsIllegalElementType(Type type)41 bool IsIllegalElementType(Type type) {
42   return type
43       .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
44            mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
45 }
46 
ToLegalElementType(Type type)47 Type ToLegalElementType(Type type) {
48   return TypeSwitch<Type, Type>(type)
49       .Case<mlir::TF::Qint8Type>([&type](Type) {
50         return mlir::IntegerType::get(type.getContext(), 8);
51       })
52       .Case<mlir::TF::Qint16Type>([&type](Type) {
53         return mlir::IntegerType::get(type.getContext(), 16);
54       })
55       .Case<mlir::TF::Qint32Type>([&type](Type) {
56         return mlir::IntegerType::get(type.getContext(), 32);
57       })
58       .Case<mlir::TF::Quint8Type>([&type](Type) {
59         return mlir::IntegerType::get(
60             type.getContext(), 8,
61             mlir::IntegerType::SignednessSemantics::Unsigned);
62       })
63       .Case<mlir::TF::Quint16Type>([&type](Type) {
64         return mlir::IntegerType::get(
65             type.getContext(), 16,
66             mlir::IntegerType::SignednessSemantics::Unsigned);
67       })
68       .Default([&type](Type) { return type; });
69 }
70 
71 // TODO(b/180234863): What's below this line is generic so convert it to a
72 // utility.
73 
IsIllegalType(Type type)74 bool IsIllegalType(Type type) {
75   return IsIllegalElementType(getElementTypeOrSelf(type));
76 }
77 
ToLegalType(Type type)78 Type ToLegalType(Type type) {
79   if (IsIllegalElementType(type)) return ToLegalElementType(type);
80   if (auto shaped = type.dyn_cast<ShapedType>()) {
81     Type elem = shaped.getElementType();
82     if (IsIllegalType(elem)) return shaped.clone(ToLegalType(elem));
83   }
84   return type;
85 }
86 
87 class TfTypeConverter : public TypeConverter {
88  public:
TfTypeConverter()89   TfTypeConverter() {
90     addConversion([](Type type) -> Type {
91       return IsIllegalType(type) ? ToLegalType(type) : type;
92     });
93   }
94 };
95 
96 // An Op is illegal iff it contains an illegalType.
97 class TfTypeConversionTarget : public ConversionTarget {
98  public:
TfTypeConversionTarget(MLIRContext & ctx,TfTypeConverter & converter)99   explicit TfTypeConversionTarget(MLIRContext &ctx, TfTypeConverter &converter)
100       : ConversionTarget(ctx), converter_(converter) {
101     markUnknownOpDynamicallyLegal([this](Operation *op) {
102       // The FuncOp type can contain types that the op's operand and result
103       // types do not contain.
104       if (auto func = dyn_cast<FuncOp>(op)) {
105         if (!converter_.isSignatureLegal(func.getType())) return false;
106       }
107       return converter_.isLegal(op);
108     });
109   }
110 
111  private:
112   TfTypeConverter &converter_;
113 };
114 
115 class TfTypePattern : public ConversionPattern {
116  public:
TfTypePattern(MLIRContext * ctx,TypeConverter & converter)117   TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
118       : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {}
119 
120   // The dialect conversion framework will call this matchAndRewrite on each
121   // Operation in the IR tree. This call matchAndRewrite needs to update the
122   // Operation's results and child regions.
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const123   LogicalResult matchAndRewrite(
124       Operation *op, ArrayRef<Value> operands,
125       ConversionPatternRewriter &rewriter) const override {
126     // Update the results.
127     llvm::SmallVector<Type, 4> new_results;
128     if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
129                                                 new_results)))
130       return failure();
131 
132     // Update the regions. The dialect conversion framework wants new regions to
133     // be created and updated, rather than updating the old op. Thus we use an
134     // OperationState so we can add regions to the new up.
135     OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
136                          new_results, op->getAttrs(), op->getSuccessors());
137     for (Region &region : op->getRegions()) {
138       Region &new_region = *state.addRegion();
139       rewriter.inlineRegionBefore(region, new_region, new_region.begin());
140       if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
141         return failure();
142     }
143     rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
144 
145     return success();
146   }
147 };
148 
149 struct LegalizeTfTypesPass
150     : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
151   void runOnOperation() override;
152 };
153 
runOnOperation()154 void LegalizeTfTypesPass::runOnOperation() {
155   TfTypeConverter converter;
156   OwningRewritePatternList patterns(&getContext());
157   patterns.insert<TfTypePattern>(&getContext(), converter);
158   populateFuncOpTypeConversionPattern(patterns, converter);
159   TfTypeConversionTarget target(getContext(), converter);
160   if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
161     return signalPassFailure();
162 }
163 
164 }  // namespace
165 
CreateLegalizeTfTypesPass()166 std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
167   return std::make_unique<LegalizeTfTypesPass>();
168 }
169 
170 }  // namespace mhlo
171 }  // namespace mlir
172