/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"

namespace mlir {
namespace mhlo {
namespace {

// This pass assumes the function to be expanded has no callers; that is, it
// behaves like an entry (main) function whose signature can be rewritten in
// place.
class ExpandHloTuplesPass
    : public ExpandHloTuplesPassBase<ExpandHloTuplesPass> {
 public:
  ExpandHloTuplesPass() = default;
  ExpandHloTuplesPass(const ExpandHloTuplesPass&) = default;
  explicit ExpandHloTuplesPass(const std::string& entryFunctionName) {
    entry_function_name_ = entryFunctionName;
  }

  // Expands any mhlo.tuple used in the return op and updates the function
  // signature accordingly.
  void expandTupledTensorInReturnOp(func::FuncOp func) {
    FunctionType oldFuncType = func.getFunctionType();
    // Update input signatures.
    // Tuples in the function inputs are flattened as well: if an input is a
    // tuple, it is expanded and then repacked as follows:
    //   func_1(%arg0: tuple<input1, input2>) =>
    //
    //   func_1(%arg0: <input1>, %arg1: <input2>) {
    //     %0 = mhlo.tuple(%arg0, %arg1)
    //   }
    SmallVector<Type, 4> expandedInputTypes;
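    // Iterate over a copy of the arguments: the loop below inserts and erases
    // block arguments, which would otherwise invalidate the iteration.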
    SmallVector<BlockArgument, 20> funcArguments(func.getArguments().begin(),
                                                 func.getArguments().end());
    for (auto argument : funcArguments) {
      auto type = argument.getType();
      auto tupleType = type.dyn_cast_or_null<TupleType>();
      if (!tupleType) {
        expandedInputTypes.push_back(type);
      } else {
        // We need to
        // 1) expand the tuple into individual arguments,
        // 2) construct a new tuple from the expanded arguments, and
        // 3) rewire uses of the original argument to the new tuple.
        int originalArgumentIndex = argument.getArgNumber();
        int argumentIndex = originalArgumentIndex;
        SmallVector<Value, 4> flattenedOperands;
        // Insert the flattened arguments right after the original tuple.
        Location loc = func.getBody().getLoc();
        for (auto flattenedType : tupleType.getTypes()) {
          expandedInputTypes.push_back(flattenedType);
          func.insertArgument(++argumentIndex, flattenedType, {}, loc);
          flattenedOperands.push_back(func.getArgument(argumentIndex));
        }

        // Construct a new tuple and rewire it.
        OpBuilder builder(func.getBody());
        builder.setInsertionPointToStart(&func.getBody().front());
        auto newTuple =
            builder.create<mhlo::TupleOp>(loc, tupleType, flattenedOperands);
        func.getArgument(originalArgumentIndex).replaceAllUsesWith(newTuple);

        // Now that the original argument has been rewired, we can safely
        // erase it.
        func.eraseArgument(originalArgumentIndex);
      }
    }

    // Update output signatures.
    auto returnOp = cast<mlir::func::ReturnOp>(func.getBody().back().back());

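    // A sketch of the rewrite performed below: a function returning a tuple,
    //   func.func @main(...) -> tuple<tensor<f32>, tensor<i32>>
    // is changed to return the tuple's operands directly,
    //   func.func @main(...) -> (tensor<f32>, tensor<i32>)
    // (the element types here are illustrative only).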
    // Expand all tuples in old return operands.
    SmallVector<Value, 4> expandedReturnOperands;
    SmallVector<Type, 4> expandedResultTypes;
    for (auto value : returnOp.getOperands()) {
      auto tuple = dyn_cast_or_null<mhlo::TupleOp>(value.getDefiningOp());
      if (!tuple) {
        expandedReturnOperands.push_back(value);
        expandedResultTypes.push_back(value.getType());
        continue;
      }

      for (auto tupleOperand : tuple.getOperands()) {
        expandedReturnOperands.push_back(tupleOperand);
        expandedResultTypes.push_back(tupleOperand.getType());
      }
    }

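    // If no tuple was expanded, the function is already in flattened form;
    // leave the return op and the signature untouched.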
    if (returnOp.getOperands() == expandedReturnOperands) return;

    OpBuilder builder(returnOp);
    builder.create<mlir::func::ReturnOp>(returnOp.getLoc(),
                                         expandedReturnOperands);
    returnOp.erase();
    auto newFuncType = FunctionType::get(
        oldFuncType.getContext(), expandedInputTypes, expandedResultTypes);
    func.setType(newFuncType);
  }

  void runOnOperation() override {
    auto module = getOperation();
    // Find the entry function.
    auto entryFunction =
        module.lookupSymbol<func::FuncOp>(entry_function_name_);
    if (!entryFunction) {
      return;
    }

    expandTupledTensorInReturnOp(entryFunction);
  }
};

}  // end namespace

std::unique_ptr<OperationPass<ModuleOp>> createExpandHloTuplesPass(
    const std::string& entryFunctionName) {
  return std::make_unique<ExpandHloTuplesPass>(entryFunctionName);
}
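
// Example usage (a minimal sketch, not part of this file): assuming a loaded
// ModuleOp `module` and the standard mlir::PassManager API, the pass can be
// scheduled like any other module pass. The entry-function name "main" is an
// illustrative choice.
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::mhlo::createExpandHloTuplesPass("main"));
//   if (mlir::failed(pm.run(module))) {
//     // ... handle pass failure ...
//   }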

}  // namespace mhlo
}  // namespace mlir