• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for flattening tuples in HLO ops.
17 
18 #include <cassert>
19 #include <string>
20 #include <utility>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/MapVector.h"
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
30 #include "mlir/IR/BlockAndValueMapping.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/Location.h"
33 #include "mlir/IR/Operation.h"
34 #include "mlir/IR/Region.h"
35 #include "mlir/IR/Value.h"
36 #include "mlir/Transforms/DialectConversion.h"
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
38 
39 namespace mlir {
40 namespace mhlo {
41 namespace {
42 
43 // Calculates the flatten types of a value.
FlattenTupleType(Value value,llvm::SmallVectorImpl<Type> & types)44 void FlattenTupleType(Value value, llvm::SmallVectorImpl<Type> &types) {
45   if (!value.getType().isa<TupleType>()) {
46     types.push_back(value.getType());
47     return;
48   }
49 
50   // This function doesn't handle nested tuple.
51   auto tupleType = value.getType().cast<TupleType>();
52   types.append(tupleType.begin(), tupleType.end());
53 }
54 
55 // Flattens value into flatten_values.
FlattenTupleValue(OpBuilder & builder,Location loc,Value value,llvm::SmallVectorImpl<Value> & flatten_values)56 void FlattenTupleValue(OpBuilder &builder, Location loc, Value value,
57                        llvm::SmallVectorImpl<Value> &flatten_values) {
58   if (!value.getType().isa<TupleType>()) {
59     flatten_values.push_back(value);
60     return;
61   }
62 
63   // This function doesn't handle nested tuple.
64   int flattenIdx = 0;
65   auto tupleType = value.getType().cast<TupleType>();
66   for (auto childType : tupleType.getTypes()) {
67     auto getTupleOp = builder.create<mhlo::GetTupleElementOp>(
68         loc, childType, value, builder.getI32IntegerAttr(flattenIdx++));
69     flatten_values.push_back(getTupleOp);
70   }
71 }
72 
73 // FlattenTupleValue and CreateTupleValue is a pair of functions to create and
74 // flatten tuples in the exact same order. CreateTupleValue returns the result
75 // of the root TupleOp or given value if the type is not TupleType.
CreateTupleValue(OpBuilder & builder,Location loc,const llvm::ArrayRef<Value> & flatten_values,Type tuple_type)76 Value CreateTupleValue(OpBuilder &builder, Location loc,
77                        const llvm::ArrayRef<Value> &flatten_values,
78                        Type tuple_type) {
79   if (!tuple_type.isa<TupleType>()) {
80     assert(flatten_values.size() == 1);
81     return flatten_values[0];
82   }
83 
84   assert(tuple_type.cast<TupleType>().getTypes().size() ==
85          flatten_values.size());
86   return builder.create<mhlo::TupleOp>(loc, flatten_values);
87 }
88 
89 // Flattens the tuples in the region's arguments and returning values.
FlattenTupleInRegion(Region & region,PatternRewriter & rewriter)90 void FlattenTupleInRegion(Region &region, PatternRewriter &rewriter) {
91   OpBuilder regionOpBuilder(region);
92 
93   // Flatten tuples in arguments. The order of arguments must match the order
94   // in FlattenTupleType, FlattenTupleValue and CreateTupleValue.
95   const int originalNumArgs = region.getNumArguments();
96   for (int argIdx : llvm::seq<int>(0, originalNumArgs)) {
97     auto argument = region.getArgument(argIdx);
98 
99     // Adds new arguments to replace the tuple argument.
100     llvm::SmallVector<Type, 4> newTypes;
101     llvm::SmallVector<Value, 4> newArguments;
102     FlattenTupleType(argument, newTypes);
103     for (auto type : newTypes) {
104       newArguments.push_back(region.addArgument(type));
105     }
106 
107     // Replaces uses of the replacing argument.
108     auto tupleValue = CreateTupleValue(regionOpBuilder, region.getLoc(),
109                                        newArguments, argument.getType());
110     argument.replaceAllUsesWith(tupleValue);
111   }
112   // Removes old tuple arguments.
113   for (int argIdx = originalNumArgs - 1; argIdx >= 0; --argIdx) {
114     region.eraseArgument(argIdx);
115   }
116 
117   // Flatten tuples in results.
118   for (auto &block : region.getBlocks()) {
119     Operation *terminator = block.getTerminator();
120     assert(isa<mhlo::ReturnOp>(terminator));
121     auto returnOp = llvm::cast<mhlo::ReturnOp>(terminator);
122 
123     // Creates a new ReturnOp with flatten values.
124     OpBuilder builder(returnOp);
125     llvm::SmallVector<Value, 4> results;
126     for (auto operand : returnOp.getOperands()) {
127       FlattenTupleValue(builder, returnOp.getLoc(), operand, results);
128     }
129     builder.create<mhlo::ReturnOp>(region.getLoc(), results);
130     rewriter.eraseOp(returnOp);
131   }
132 }
133 
134 // Applies tuple flattening patterns to given target. This helper
135 // function is used to flatten ops recursively.
136 template <typename T>
137 void ApplyFlatteningTuplePatterns(T target, MLIRContext *context);
138 
139 struct FlattenWhileOp : public RewritePattern {
FlattenWhileOpmlir::mhlo::__anon77c078a20111::FlattenWhileOp140   explicit FlattenWhileOp(MLIRContext *context)
141       : RewritePattern(mhlo::WhileOp::getOperationName(), 1, context,
142                        {mhlo::WhileOp::getOperationName(),
143                         mhlo::TupleOp::getOperationName(),
144                         mhlo::GetTupleElementOp::getOperationName()}),
145         context(context) {}
146 
matchAndRewritemlir::mhlo::__anon77c078a20111::FlattenWhileOp147   LogicalResult matchAndRewrite(Operation *op,
148                                 PatternRewriter &rewriter) const override {
149     auto whileOp = cast<mhlo::WhileOp>(op);
150     // HLO WhileOp should have two regions: cond and body.
151     if (whileOp->getNumRegions() != 2) return failure();
152 
153     // Operands of mhlo::WhileOp can be a variadic list of tensors and
154     // tuples. Tuples need to be flattened in order to be used in
155     // TF::WhileOp. Note that in WhileOp, operand and result types are
156     // always the same.
157     OpBuilder builder(whileOp);
158     llvm::SmallVector<Value, 4> flattenedOperands;
159     llvm::SmallVector<Type, 4> flattenedOperandTypes;
160     for (auto operand : whileOp->getOperands()) {
161       FlattenTupleType(operand, flattenedOperandTypes);
162       FlattenTupleValue(builder, whileOp.getLoc(), operand, flattenedOperands);
163     }
164 
165     // The applyPatternsAndFoldGreedily can't be called on child regions, so
166     // creates temporary regions to apply flattening rules recursively.
167     auto module = whileOp->getParentOfType<ModuleOp>();
168     BlockAndValueMapping mapping;
169     Region newCond(module);
170     whileOp.cond().cloneInto(&newCond, mapping);
171     Region newBody(module);
172     whileOp.body().cloneInto(&newBody, mapping);
173 
174     // Flattens the tuples in child regions.
175     FlattenTupleInRegion(newCond, rewriter);
176     FlattenTupleInRegion(newBody, rewriter);
177 
178     // There might be WhileOp in child regions, flattens tuple in them too.
179     ApplyFlatteningTuplePatterns<MutableArrayRef<Region>>(newCond, context);
180     ApplyFlatteningTuplePatterns<MutableArrayRef<Region>>(newBody, context);
181 
182     // Creates a new mhlo::WhileOp with no tuples.
183     auto newWhile = builder.create<mhlo::WhileOp>(
184         whileOp.getLoc(), flattenedOperandTypes, flattenedOperands);
185     newCond.cloneInto(&newWhile.cond(), mapping);
186     newBody.cloneInto(&newWhile.body(), mapping);
187 
188     // Replaces uses of the old WhileOp.
189     auto newResultIter = newWhile.result_begin();
190     for (auto oldResult : whileOp.getResults()) {
191       llvm::SmallVector<Type, 4> flattenedTypes;
192       FlattenTupleType(oldResult, flattenedTypes);
193       llvm::SmallVector<Value, 4> flattenedResults;
194       while (flattenedResults.size() < flattenedTypes.size()) {
195         assert(newResultIter != newWhile->result_end());
196         flattenedResults.push_back(*newResultIter++);
197       }
198       auto tupleValue = CreateTupleValue(builder, whileOp.getLoc(),
199                                          flattenedResults, oldResult.getType());
200       oldResult.replaceAllUsesWith(tupleValue);
201     }
202     rewriter.eraseOp(whileOp);
203     return success();
204   }
205 
206  private:
207   MLIRContext *context;
208 };
209 
210 template <typename T>
ApplyFlatteningTuplePatterns(T target,MLIRContext * context)211 void ApplyFlatteningTuplePatterns(T target, MLIRContext *context) {
212   OwningRewritePatternList patterns(context);
213   patterns.insert<FlattenWhileOp>(context);
214   (void)applyPatternsAndFoldGreedily(target, std::move(patterns));
215 }
216 
217 class FlattenTuplePass : public FlattenTuplePassBase<FlattenTuplePass> {
218  public:
runOnFunction()219   void runOnFunction() override {
220     MLIRContext *ctx = &getContext();
221     ApplyFlatteningTuplePatterns(getFunction(), ctx);
222   }
223 };
224 }  // end namespace
225 
226 static PassRegistration<FlattenTuplePass> pass;
227 
createFlattenTuplePass()228 std::unique_ptr<FunctionPass> createFlattenTuplePass() {
229   return std::make_unique<FlattenTuplePass>();
230 }
231 
232 }  // end namespace mhlo
233 }  // end namespace mlir
234