1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The TF dialect uses some TF types that are illegal in the MHLO dialect and
17 // some generic types that are legal in MHLO. This pass legalizes TF types into
18 // types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
19 // Rewrites here should run before TF to MHLO op legalizations are run.
20 // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
21 // rather than its own pass.
22
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Support/LLVM.h" // from @llvm-project
30 #include "mlir/Support/LogicalResult.h" // from @llvm-project
31 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
33 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
34
35 #define DEBUG_TYPE "xla-legalize-tf-types"
36
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40
IsIllegalElementType(Type type)41 bool IsIllegalElementType(Type type) {
42 return type
43 .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
44 mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
45 }
46
ToLegalElementType(Type type)47 Type ToLegalElementType(Type type) {
48 return TypeSwitch<Type, Type>(type)
49 .Case<mlir::TF::Qint8Type>([&type](Type) {
50 return mlir::IntegerType::get(type.getContext(), 8);
51 })
52 .Case<mlir::TF::Qint16Type>([&type](Type) {
53 return mlir::IntegerType::get(type.getContext(), 16);
54 })
55 .Case<mlir::TF::Qint32Type>([&type](Type) {
56 return mlir::IntegerType::get(type.getContext(), 32);
57 })
58 .Case<mlir::TF::Quint8Type>([&type](Type) {
59 return mlir::IntegerType::get(
60 type.getContext(), 8,
61 mlir::IntegerType::SignednessSemantics::Unsigned);
62 })
63 .Case<mlir::TF::Quint16Type>([&type](Type) {
64 return mlir::IntegerType::get(
65 type.getContext(), 16,
66 mlir::IntegerType::SignednessSemantics::Unsigned);
67 })
68 .Default([&type](Type) { return type; });
69 }
70
71 // TODO(b/180234863): What's below this line is generic so convert it to a
72 // utility.
73
IsIllegalType(Type type)74 bool IsIllegalType(Type type) {
75 return IsIllegalElementType(getElementTypeOrSelf(type));
76 }
77
ToLegalType(Type type)78 Type ToLegalType(Type type) {
79 if (IsIllegalElementType(type)) return ToLegalElementType(type);
80 if (auto shaped = type.dyn_cast<ShapedType>()) {
81 Type elem = shaped.getElementType();
82 if (IsIllegalType(elem)) return shaped.clone(ToLegalType(elem));
83 }
84 return type;
85 }
86
87 class TfTypeConverter : public TypeConverter {
88 public:
TfTypeConverter()89 TfTypeConverter() {
90 addConversion([](Type type) -> Type {
91 return IsIllegalType(type) ? ToLegalType(type) : type;
92 });
93 }
94 };
95
96 // An Op is illegal iff it contains an illegalType.
97 class TfTypeConversionTarget : public ConversionTarget {
98 public:
TfTypeConversionTarget(MLIRContext & ctx,TfTypeConverter & converter)99 explicit TfTypeConversionTarget(MLIRContext &ctx, TfTypeConverter &converter)
100 : ConversionTarget(ctx), converter_(converter) {
101 markUnknownOpDynamicallyLegal([this](Operation *op) {
102 // The FuncOp type can contain types that the op's operand and result
103 // types do not contain.
104 if (auto func = dyn_cast<FuncOp>(op)) {
105 if (!converter_.isSignatureLegal(func.getType())) return false;
106 }
107 return converter_.isLegal(op);
108 });
109 }
110
111 private:
112 TfTypeConverter &converter_;
113 };
114
115 class TfTypePattern : public ConversionPattern {
116 public:
TfTypePattern(MLIRContext * ctx,TypeConverter & converter)117 TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
118 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {}
119
120 // The dialect conversion framework will call this matchAndRewrite on each
121 // Operation in the IR tree. This call matchAndRewrite needs to update the
122 // Operation's results and child regions.
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const123 LogicalResult matchAndRewrite(
124 Operation *op, ArrayRef<Value> operands,
125 ConversionPatternRewriter &rewriter) const override {
126 // Update the results.
127 llvm::SmallVector<Type, 4> new_results;
128 if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
129 new_results)))
130 return failure();
131
132 // Update the regions. The dialect conversion framework wants new regions to
133 // be created and updated, rather than updating the old op. Thus we use an
134 // OperationState so we can add regions to the new up.
135 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
136 new_results, op->getAttrs(), op->getSuccessors());
137 for (Region ®ion : op->getRegions()) {
138 Region &new_region = *state.addRegion();
139 rewriter.inlineRegionBefore(region, new_region, new_region.begin());
140 if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
141 return failure();
142 }
143 rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
144
145 return success();
146 }
147 };
148
149 struct LegalizeTfTypesPass
150 : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
151 void runOnOperation() override;
152 };
153
runOnOperation()154 void LegalizeTfTypesPass::runOnOperation() {
155 TfTypeConverter converter;
156 OwningRewritePatternList patterns(&getContext());
157 patterns.insert<TfTypePattern>(&getContext(), converter);
158 populateFuncOpTypeConversionPattern(patterns, converter);
159 TfTypeConversionTarget target(getContext(), converter);
160 if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
161 return signalPassFailure();
162 }
163
164 } // namespace
165
CreateLegalizeTfTypesPass()166 std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
167 return std::make_unique<LegalizeTfTypesPass>();
168 }
169
170 } // namespace mhlo
171 } // namespace mlir
172