/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"

namespace mlir {
namespace mhlo {
namespace {
// This pass assumes the function to be expanded has no callees; that is,
// the function behaves like the entry (main) function.
class ExpandHloTuplesPass
    : public ExpandHloTuplesPassBase<ExpandHloTuplesPass> {
 public:
  ExpandHloTuplesPass() = default;
  ExpandHloTuplesPass(const ExpandHloTuplesPass&) {}
  explicit ExpandHloTuplesPass(const std::string& entry_function_name) {
    entry_function_name_ = entry_function_name;
  }
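  // Note: entry_function_name_ is presumably the pass option declared for
  // this pass in the generated ExpandHloTuplesPassBase (see PassDetail.h);
  // it selects which function has its tupled signature flattened.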

  // Expands the mhlo.tuple used in the return op, and updates the function
  // signature accordingly.
  void ExpandTupledTensorInReturnOp(FuncOp func) {
    FunctionType old_func_type = func.getType();
    // Update input signatures.
    // We also flatten any tuples among the function inputs, so a tuple input
    // is expanded and repacked as follows:
    //   func_1(%arg0: tuple<input1, input2>) =>
    //
    //   func_1(%arg0: <input1>, %arg1: <input2>) {
    //     %0 = mhlo.tuple(%arg0, %arg1)
    //   }
    SmallVector<Type, 4> expanded_input_types;
    SmallVector<BlockArgument, 20> func_arguments(func.getArguments().begin(),
                                                  func.getArguments().end());
    for (auto argument : func_arguments) {
      auto type = argument.getType();
      auto tuple_type = type.dyn_cast_or_null<TupleType>();
      if (!tuple_type) {
        expanded_input_types.push_back(type);
      } else {
        // We need to
        // 1) expand the tuple
        // 2) insert a new tuple
        // 3) rewire uses to the new tuple
        int original_argument_index = argument.getArgNumber();
        int argument_index = original_argument_index;
        SmallVector<Value, 4> flattened_operands;
        // Insert the flattened tuple elements after the original tuple
        // argument.
        for (auto flattened_type : tuple_type.getTypes()) {
          expanded_input_types.push_back(flattened_type);
          func.insertArgument(++argument_index, flattened_type, {});
          flattened_operands.push_back(func.getArgument(argument_index));
        }

        // Construct a new tuple from the flattened arguments and rewire the
        // uses of the original argument to it.
        OpBuilder builder(func.body());
        builder.setInsertionPointToStart(&func.body().front());
        auto new_tuple = builder.create<mhlo::TupleOp>(
            func.body().getLoc(), tuple_type, flattened_operands);
        func.getArgument(original_argument_index).replaceAllUsesWith(new_tuple);

        // Now that the original argument has been rewired, we can safely
        // erase it.
        func.eraseArgument(original_argument_index);
      }
    }

    // Update output signatures.
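    // Illustrative sketch (not from the original comments), assuming a
    // return operand produced by an mhlo.tuple op:
    //   %t = mhlo.tuple(%a, %b) : ... -> tuple<tensor<f32>, tensor<i32>>
    //   return %t : tuple<tensor<f32>, tensor<i32>>
    // is rewritten to
    //   return %a, %b : tensor<f32>, tensor<i32>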
    auto return_op = cast<mlir::ReturnOp>(func.getBody().back().back());

    // Expand all tuples in old return operands.
    SmallVector<Value, 4> expanded_return_operands;
    SmallVector<Type, 4> expanded_result_types;
    for (auto value : return_op.getOperands()) {
      auto tuple = dyn_cast_or_null<mhlo::TupleOp>(value.getDefiningOp());
      if (!tuple) {
        expanded_return_operands.push_back(value);
        expanded_result_types.push_back(value.getType());
        continue;
      }

      for (auto tuple_operand : tuple.getOperands()) {
        expanded_return_operands.push_back(tuple_operand);
        expanded_result_types.push_back(tuple_operand.getType());
      }
    }

    // If there is nothing to return, bail out and leave the return op and
    // the function type untouched.
    if (expanded_return_operands.empty()) return;

    OpBuilder builder(return_op);
    builder.create<mlir::ReturnOp>(return_op.getLoc(),
                                   expanded_return_operands);
    return_op.erase();
    auto new_func_type =
        FunctionType::get(old_func_type.getContext(), expanded_input_types,
                          expanded_result_types);
    func.setType(new_func_type);
  }

  void runOnOperation() override {
    auto module = getOperation();
    // Find the entry function.
    auto entry_function = module.lookupSymbol<FuncOp>(entry_function_name_);
    if (!entry_function) {
      return;
    }

    ExpandTupledTensorInReturnOp(entry_function);
  }
};

}  // end namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateExpandHloTuplesPass(
    const std::string& entry_function_name) {
  return std::make_unique<ExpandHloTuplesPass>(entry_function_name);
}
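
// Illustrative usage sketch (not part of the original file), assuming the
// pass is scheduled on a ModuleOp like any other module pass:
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::mhlo::CreateExpandHloTuplesPass("main"));
//   if (failed(pm.run(module))) { /* handle the failure */ }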

}  // namespace mhlo
}  // namespace mlir