• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// The TF dialect uses some TF types that are illegal in the MHLO dialect and
// some generic types that are legal in MHLO. This pass legalizes TF types into
// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
// This pass should run before the TF-to-MHLO op legalizations.
20 // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
21 // rather than its own pass.
22 
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Support/LLVM.h"  // from @llvm-project
30 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
31 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
33 #include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
34 
35 #define DEBUG_TYPE "xla-legalize-tf-types"
36 
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40 
isIllegalElementType(Type type)41 bool isIllegalElementType(Type type) {
42   return type
43       .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
44            mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
45 }
46 
replaceElementType(Type type)47 Type replaceElementType(Type type) {
48   return TypeSwitch<Type, Type>(type)
49       .Case<mlir::TF::Qint8Type>([&type](Type) {
50         return mlir::IntegerType::get(type.getContext(), 8);
51       })
52       .Case<mlir::TF::Qint16Type>([&type](Type) {
53         return mlir::IntegerType::get(type.getContext(), 16);
54       })
55       .Case<mlir::TF::Qint32Type>([&type](Type) {
56         return mlir::IntegerType::get(type.getContext(), 32);
57       })
58       .Case<mlir::TF::Quint8Type>([&type](Type) {
59         return mlir::IntegerType::get(
60             type.getContext(), 8,
61             mlir::IntegerType::SignednessSemantics::Unsigned);
62       })
63       .Case<mlir::TF::Quint16Type>([&type](Type) {
64         return mlir::IntegerType::get(
65             type.getContext(), 16,
66             mlir::IntegerType::SignednessSemantics::Unsigned);
67       })
68       .Default([&type](Type) { return type; });
69 }
70 
71 // TODO(b/180234863): What's below this line is generic so convert it to a
72 // utility.
73 
isIllegalType(Type type)74 bool isIllegalType(Type type) {
75   if (isIllegalElementType(type)) return true;
76   if (auto shaped = type.dyn_cast<ShapedType>())
77     return isIllegalType(shaped.getElementType());
78   return false;
79 }
80 
replaceType(Type type)81 Type replaceType(Type type) {
82   if (isIllegalElementType(type)) return replaceElementType(type);
83   if (auto shaped = type.dyn_cast<ShapedType>()) {
84     Type elem = shaped.getElementType();
85     if (isIllegalType(elem)) return shaped.clone(replaceType(elem));
86   }
87   return type;
88 }
89 
90 // An Op is illegal iff it contains an illegalType.
91 class TfTypeConversionTarget : public ConversionTarget {
92  public:
TfTypeConversionTarget(MLIRContext & ctx)93   explicit TfTypeConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
94     markUnknownOpDynamicallyLegal();
95   }
96 
97  protected:
isDynamicallyLegal(Operation * op) const98   bool isDynamicallyLegal(Operation *op) const override {
99     // The FuncOp type can contain types that the op's operand and result types
100     // do not contain.
101     if (auto func = dyn_cast<FuncOp>(op)) {
102       if (llvm::any_of(func.getType().getInputs(), isIllegalType) ||
103           llvm::any_of(func.getType().getResults(), isIllegalType))
104         return false;
105     }
106     if (llvm::any_of(op->getOperandTypes(), isIllegalType) ||
107         llvm::any_of(op->getResultTypes(), isIllegalType))
108       return false;
109     return true;
110   }
111 };
112 
113 class TfTypeConverter : public TypeConverter {
114  public:
TfTypeConverter()115   TfTypeConverter() {
116     addConversion([](Type type) -> Type {
117       if (isIllegalType(type))
118         return replaceType(type);
119       else
120         return type;
121     });
122   }
123 };
124 
125 class TfTypePattern : public ConversionPattern {
126  public:
TfTypePattern(MLIRContext * ctx,TypeConverter & converter)127   TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
128       : ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}
129 
130   // The dialect conversion framework will call this matchAndRewrite on each
131   // Operation in the IR tree. This call matchAndRewrite needs to update the
132   // Operation's results and child regions.
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const133   LogicalResult matchAndRewrite(
134       Operation *op, ArrayRef<Value> operands,
135       ConversionPatternRewriter &rewriter) const override {
136     // Update the results.
137     llvm::SmallVector<Type, 4> new_results;
138     if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
139                                                 new_results)))
140       return failure();
141 
142     // Update the regions. The dialect conversion framework wants new regions to
143     // be created and updated, rather than updating the old op. Thus we use an
144     // OperationState so we can add regions to the new up.
145     OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
146                          new_results, op->getAttrs(), op->getSuccessors());
147     for (Region &region : op->getRegions()) {
148       Region &new_region = *state.addRegion();
149       rewriter.inlineRegionBefore(region, new_region, new_region.begin());
150       if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
151         return failure();
152     }
153     rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
154 
155     return success();
156   }
157 };
158 
159 struct LegalizeTfTypesPass
160     : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
161   void runOnOperation() override;
162 };
163 
runOnOperation()164 void LegalizeTfTypesPass::runOnOperation() {
165   TfTypeConverter converter;
166   OwningRewritePatternList patterns;
167   patterns.insert<TfTypePattern>(&getContext(), converter);
168   populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
169   TfTypeConversionTarget target(getContext());
170   if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
171     return signalPassFailure();
172 }
173 
174 static PassRegistration<LegalizeTfTypesPass> registration(
175     "xla-legalize-tf-types",
176     "Replace TensorFlow types with types that are legal in the MHLO dialect");
177 
178 }  // namespace
179 
CreateLegalizeTfTypesPass()180 std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
181   return std::make_unique<LegalizeTfTypesPass>();
182 }
183 
184 }  // namespace mhlo
185 }  // namespace mlir
186