1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for flattening tuples in HLO ops.
17
18 #include <cassert>
19 #include <string>
20 #include <utility>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/MapVector.h"
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
30 #include "mlir/IR/BlockAndValueMapping.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/Location.h"
33 #include "mlir/IR/Operation.h"
34 #include "mlir/IR/Region.h"
35 #include "mlir/IR/Value.h"
36 #include "mlir/Transforms/DialectConversion.h"
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
38
39 namespace mlir {
40 namespace mhlo {
41 namespace {
42
43 // Calculates the flatten types of a value.
FlattenTupleType(Value value,llvm::SmallVectorImpl<Type> & types)44 void FlattenTupleType(Value value, llvm::SmallVectorImpl<Type> &types) {
45 if (!value.getType().isa<TupleType>()) {
46 types.push_back(value.getType());
47 return;
48 }
49
50 // This function doesn't handle nested tuple.
51 auto tupleType = value.getType().cast<TupleType>();
52 types.append(tupleType.begin(), tupleType.end());
53 }
54
55 // Flattens value into flatten_values.
FlattenTupleValue(OpBuilder & builder,Location loc,Value value,llvm::SmallVectorImpl<Value> & flatten_values)56 void FlattenTupleValue(OpBuilder &builder, Location loc, Value value,
57 llvm::SmallVectorImpl<Value> &flatten_values) {
58 if (!value.getType().isa<TupleType>()) {
59 flatten_values.push_back(value);
60 return;
61 }
62
63 // This function doesn't handle nested tuple.
64 int flattenIdx = 0;
65 auto tupleType = value.getType().cast<TupleType>();
66 for (auto childType : tupleType.getTypes()) {
67 auto getTupleOp = builder.create<mhlo::GetTupleElementOp>(
68 loc, childType, value, builder.getI32IntegerAttr(flattenIdx++));
69 flatten_values.push_back(getTupleOp);
70 }
71 }
72
73 // FlattenTupleValue and CreateTupleValue is a pair of functions to create and
74 // flatten tuples in the exact same order. CreateTupleValue returns the result
75 // of the root TupleOp or given value if the type is not TupleType.
CreateTupleValue(OpBuilder & builder,Location loc,const llvm::ArrayRef<Value> & flatten_values,Type tuple_type)76 Value CreateTupleValue(OpBuilder &builder, Location loc,
77 const llvm::ArrayRef<Value> &flatten_values,
78 Type tuple_type) {
79 if (!tuple_type.isa<TupleType>()) {
80 assert(flatten_values.size() == 1);
81 return flatten_values[0];
82 }
83
84 assert(tuple_type.cast<TupleType>().getTypes().size() ==
85 flatten_values.size());
86 return builder.create<mhlo::TupleOp>(loc, flatten_values);
87 }
88
89 // Flattens the tuples in the region's arguments and returning values.
FlattenTupleInRegion(Region & region,PatternRewriter & rewriter)90 void FlattenTupleInRegion(Region ®ion, PatternRewriter &rewriter) {
91 OpBuilder regionOpBuilder(region);
92
93 // Flatten tuples in arguments. The order of arguments must match the order
94 // in FlattenTupleType, FlattenTupleValue and CreateTupleValue.
95 const int originalNumArgs = region.getNumArguments();
96 for (int argIdx : llvm::seq<int>(0, originalNumArgs)) {
97 auto argument = region.getArgument(argIdx);
98
99 // Adds new arguments to replace the tuple argument.
100 llvm::SmallVector<Type, 4> newTypes;
101 llvm::SmallVector<Value, 4> newArguments;
102 FlattenTupleType(argument, newTypes);
103 for (auto type : newTypes) {
104 newArguments.push_back(region.addArgument(type));
105 }
106
107 // Replaces uses of the replacing argument.
108 auto tupleValue = CreateTupleValue(regionOpBuilder, region.getLoc(),
109 newArguments, argument.getType());
110 argument.replaceAllUsesWith(tupleValue);
111 }
112 // Removes old tuple arguments.
113 for (int argIdx = originalNumArgs - 1; argIdx >= 0; --argIdx) {
114 region.eraseArgument(argIdx);
115 }
116
117 // Flatten tuples in results.
118 for (auto &block : region.getBlocks()) {
119 Operation *terminator = block.getTerminator();
120 assert(isa<mhlo::ReturnOp>(terminator));
121 auto returnOp = llvm::cast<mhlo::ReturnOp>(terminator);
122
123 // Creates a new ReturnOp with flatten values.
124 OpBuilder builder(returnOp);
125 llvm::SmallVector<Value, 4> results;
126 for (auto operand : returnOp.getOperands()) {
127 FlattenTupleValue(builder, returnOp.getLoc(), operand, results);
128 }
129 builder.create<mhlo::ReturnOp>(region.getLoc(), results);
130 rewriter.eraseOp(returnOp);
131 }
132 }
133
134 // Applies tuple flattening patterns to given target. This helper
135 // function is used to flatten ops recursively.
136 template <typename T>
137 void ApplyFlatteningTuplePatterns(T target, MLIRContext *context);
138
139 struct FlattenWhileOp : public RewritePattern {
FlattenWhileOpmlir::mhlo::__anon77c078a20111::FlattenWhileOp140 explicit FlattenWhileOp(MLIRContext *context)
141 : RewritePattern(mhlo::WhileOp::getOperationName(), 1, context,
142 {mhlo::WhileOp::getOperationName(),
143 mhlo::TupleOp::getOperationName(),
144 mhlo::GetTupleElementOp::getOperationName()}),
145 context(context) {}
146
matchAndRewritemlir::mhlo::__anon77c078a20111::FlattenWhileOp147 LogicalResult matchAndRewrite(Operation *op,
148 PatternRewriter &rewriter) const override {
149 auto whileOp = cast<mhlo::WhileOp>(op);
150 // HLO WhileOp should have two regions: cond and body.
151 if (whileOp->getNumRegions() != 2) return failure();
152
153 // Operands of mhlo::WhileOp can be a variadic list of tensors and
154 // tuples. Tuples need to be flattened in order to be used in
155 // TF::WhileOp. Note that in WhileOp, operand and result types are
156 // always the same.
157 OpBuilder builder(whileOp);
158 llvm::SmallVector<Value, 4> flattenedOperands;
159 llvm::SmallVector<Type, 4> flattenedOperandTypes;
160 for (auto operand : whileOp->getOperands()) {
161 FlattenTupleType(operand, flattenedOperandTypes);
162 FlattenTupleValue(builder, whileOp.getLoc(), operand, flattenedOperands);
163 }
164
165 // The applyPatternsAndFoldGreedily can't be called on child regions, so
166 // creates temporary regions to apply flattening rules recursively.
167 auto module = whileOp->getParentOfType<ModuleOp>();
168 BlockAndValueMapping mapping;
169 Region newCond(module);
170 whileOp.cond().cloneInto(&newCond, mapping);
171 Region newBody(module);
172 whileOp.body().cloneInto(&newBody, mapping);
173
174 // Flattens the tuples in child regions.
175 FlattenTupleInRegion(newCond, rewriter);
176 FlattenTupleInRegion(newBody, rewriter);
177
178 // There might be WhileOp in child regions, flattens tuple in them too.
179 ApplyFlatteningTuplePatterns<MutableArrayRef<Region>>(newCond, context);
180 ApplyFlatteningTuplePatterns<MutableArrayRef<Region>>(newBody, context);
181
182 // Creates a new mhlo::WhileOp with no tuples.
183 auto newWhile = builder.create<mhlo::WhileOp>(
184 whileOp.getLoc(), flattenedOperandTypes, flattenedOperands);
185 newCond.cloneInto(&newWhile.cond(), mapping);
186 newBody.cloneInto(&newWhile.body(), mapping);
187
188 // Replaces uses of the old WhileOp.
189 auto newResultIter = newWhile.result_begin();
190 for (auto oldResult : whileOp.getResults()) {
191 llvm::SmallVector<Type, 4> flattenedTypes;
192 FlattenTupleType(oldResult, flattenedTypes);
193 llvm::SmallVector<Value, 4> flattenedResults;
194 while (flattenedResults.size() < flattenedTypes.size()) {
195 assert(newResultIter != newWhile->result_end());
196 flattenedResults.push_back(*newResultIter++);
197 }
198 auto tupleValue = CreateTupleValue(builder, whileOp.getLoc(),
199 flattenedResults, oldResult.getType());
200 oldResult.replaceAllUsesWith(tupleValue);
201 }
202 rewriter.eraseOp(whileOp);
203 return success();
204 }
205
206 private:
207 MLIRContext *context;
208 };
209
210 template <typename T>
ApplyFlatteningTuplePatterns(T target,MLIRContext * context)211 void ApplyFlatteningTuplePatterns(T target, MLIRContext *context) {
212 OwningRewritePatternList patterns(context);
213 patterns.insert<FlattenWhileOp>(context);
214 (void)applyPatternsAndFoldGreedily(target, std::move(patterns));
215 }
216
217 class FlattenTuplePass : public FlattenTuplePassBase<FlattenTuplePass> {
218 public:
runOnFunction()219 void runOnFunction() override {
220 MLIRContext *ctx = &getContext();
221 ApplyFlatteningTuplePatterns(getFunction(), ctx);
222 }
223 };
224 } // end namespace
225
226 static PassRegistration<FlattenTuplePass> pass;
227
createFlattenTuplePass()228 std::unique_ptr<FunctionPass> createFlattenTuplePass() {
229 return std::make_unique<FlattenTuplePass>();
230 }
231
232 } // end namespace mhlo
233 } // end namespace mlir
234