1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The TF dialect uses some TF types that are illegal in the MHLO dialect and
17 // some generic types that are legal in MHLO. This pass legalizes TF types into
18 // types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
19 // Rewrites here should run before TF to MHLO op legalizations are run.
20 // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
21 // rather than its own pass.
22
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Support/LLVM.h" // from @llvm-project
30 #include "mlir/Support/LogicalResult.h" // from @llvm-project
31 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
33 #include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
34
35 #define DEBUG_TYPE "xla-legalize-tf-types"
36
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40
isIllegalElementType(Type type)41 bool isIllegalElementType(Type type) {
42 return type
43 .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
44 mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
45 }
46
replaceElementType(Type type)47 Type replaceElementType(Type type) {
48 return TypeSwitch<Type, Type>(type)
49 .Case<mlir::TF::Qint8Type>([&type](Type) {
50 return mlir::IntegerType::get(type.getContext(), 8);
51 })
52 .Case<mlir::TF::Qint16Type>([&type](Type) {
53 return mlir::IntegerType::get(type.getContext(), 16);
54 })
55 .Case<mlir::TF::Qint32Type>([&type](Type) {
56 return mlir::IntegerType::get(type.getContext(), 32);
57 })
58 .Case<mlir::TF::Quint8Type>([&type](Type) {
59 return mlir::IntegerType::get(
60 type.getContext(), 8,
61 mlir::IntegerType::SignednessSemantics::Unsigned);
62 })
63 .Case<mlir::TF::Quint16Type>([&type](Type) {
64 return mlir::IntegerType::get(
65 type.getContext(), 16,
66 mlir::IntegerType::SignednessSemantics::Unsigned);
67 })
68 .Default([&type](Type) { return type; });
69 }
70
71 // TODO(b/180234863): What's below this line is generic so convert it to a
72 // utility.
73
isIllegalType(Type type)74 bool isIllegalType(Type type) {
75 if (isIllegalElementType(type)) return true;
76 if (auto shaped = type.dyn_cast<ShapedType>())
77 return isIllegalType(shaped.getElementType());
78 return false;
79 }
80
replaceType(Type type)81 Type replaceType(Type type) {
82 if (isIllegalElementType(type)) return replaceElementType(type);
83 if (auto shaped = type.dyn_cast<ShapedType>()) {
84 Type elem = shaped.getElementType();
85 if (isIllegalType(elem)) return shaped.clone(replaceType(elem));
86 }
87 return type;
88 }
89
90 // An Op is illegal iff it contains an illegalType.
91 class TfTypeConversionTarget : public ConversionTarget {
92 public:
TfTypeConversionTarget(MLIRContext & ctx)93 explicit TfTypeConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
94 markUnknownOpDynamicallyLegal();
95 }
96
97 protected:
isDynamicallyLegal(Operation * op) const98 bool isDynamicallyLegal(Operation *op) const override {
99 // The FuncOp type can contain types that the op's operand and result types
100 // do not contain.
101 if (auto func = dyn_cast<FuncOp>(op)) {
102 if (llvm::any_of(func.getType().getInputs(), isIllegalType) ||
103 llvm::any_of(func.getType().getResults(), isIllegalType))
104 return false;
105 }
106 if (llvm::any_of(op->getOperandTypes(), isIllegalType) ||
107 llvm::any_of(op->getResultTypes(), isIllegalType))
108 return false;
109 return true;
110 }
111 };
112
113 class TfTypeConverter : public TypeConverter {
114 public:
TfTypeConverter()115 TfTypeConverter() {
116 addConversion([](Type type) -> Type {
117 if (isIllegalType(type))
118 return replaceType(type);
119 else
120 return type;
121 });
122 }
123 };
124
125 class TfTypePattern : public ConversionPattern {
126 public:
TfTypePattern(MLIRContext * ctx,TypeConverter & converter)127 TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
128 : ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}
129
130 // The dialect conversion framework will call this matchAndRewrite on each
131 // Operation in the IR tree. This call matchAndRewrite needs to update the
132 // Operation's results and child regions.
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const133 LogicalResult matchAndRewrite(
134 Operation *op, ArrayRef<Value> operands,
135 ConversionPatternRewriter &rewriter) const override {
136 // Update the results.
137 llvm::SmallVector<Type, 4> new_results;
138 if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
139 new_results)))
140 return failure();
141
142 // Update the regions. The dialect conversion framework wants new regions to
143 // be created and updated, rather than updating the old op. Thus we use an
144 // OperationState so we can add regions to the new up.
145 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
146 new_results, op->getAttrs(), op->getSuccessors());
147 for (Region ®ion : op->getRegions()) {
148 Region &new_region = *state.addRegion();
149 rewriter.inlineRegionBefore(region, new_region, new_region.begin());
150 if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
151 return failure();
152 }
153 rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
154
155 return success();
156 }
157 };
158
159 struct LegalizeTfTypesPass
160 : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
161 void runOnOperation() override;
162 };
163
runOnOperation()164 void LegalizeTfTypesPass::runOnOperation() {
165 TfTypeConverter converter;
166 OwningRewritePatternList patterns;
167 patterns.insert<TfTypePattern>(&getContext(), converter);
168 populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
169 TfTypeConversionTarget target(getContext());
170 if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
171 return signalPassFailure();
172 }
173
174 static PassRegistration<LegalizeTfTypesPass> registration(
175 "xla-legalize-tf-types",
176 "Replace TensorFlow types with types that are legal in the MHLO dialect");
177
178 } // namespace
179
CreateLegalizeTfTypesPass()180 std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
181 return std::make_unique<LegalizeTfTypesPass>();
182 }
183
184 } // namespace mhlo
185 } // namespace mlir
186