1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"
25 #include "mlir/IR/Attributes.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/PatternMatch.h"
31 #include "mlir/IR/Value.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Pass/PassRegistry.h"
34 
35 namespace mlir {
36 namespace mhlo {
37 namespace {
38 
39 // This pass assumes the function to be expanded has no callees, to be specific,
40 // the function is more like the main function.
41 class ExpandHloTuplesPass
42     : public ExpandHloTuplesPassBase<ExpandHloTuplesPass> {
43  public:
44   ExpandHloTuplesPass() = default;
ExpandHloTuplesPass(const ExpandHloTuplesPass &)45   ExpandHloTuplesPass(const ExpandHloTuplesPass&) {}
ExpandHloTuplesPass(const std::string & entry_function_name)46   explicit ExpandHloTuplesPass(const std::string& entry_function_name) {
47     entry_function_name_ = entry_function_name;
48   }
49 
50   // Expands the mhlo.tuple used in return op. Also updates function
51   // signature accordingly.
ExpandTupledTensorInReturnOp(FuncOp func)52   void ExpandTupledTensorInReturnOp(FuncOp func) {
53     FunctionType old_func_type = func.getType();
54     // Update input signatures.
55     // We will flatten the tuples for the function inputs as well.
56     // So if an input is tuple, will be flattened and packed as following:
57     // func_1(%arg0: tuple<input1, input2>) =>
58     //
59     // func_1(%arg0: <input1>, %arg1: <input2>) {
60     //  %0 = mhlo.tuple(%arg0, %arg1)
61     // }
62     SmallVector<Type, 4> expanded_input_types;
63     SmallVector<BlockArgument, 20> func_arguments(func.getArguments().begin(),
64                                                   func.getArguments().end());
65     for (auto argument : func_arguments) {
66       auto type = argument.getType();
67       auto tuple_type = type.dyn_cast_or_null<TupleType>();
68       if (!tuple_type) {
69         expanded_input_types.push_back(type);
70       } else {
71         // We need to
72         // 1) expand the tuple
73         // 2) insert a new tuple
74         // 3) rewire the new tuple
75         int original_argument_index = argument.getArgNumber();
76         int argument_index = original_argument_index;
77         SmallVector<Value, 4> flattened_operands;
78         // insert the flattened tuples after the original tuple.
79         for (auto flattened_type : tuple_type.getTypes()) {
80           expanded_input_types.push_back(flattened_type);
81           func.insertArgument(++argument_index, flattened_type, {});
82           flattened_operands.push_back(func.getArgument(argument_index));
83         }
84 
85         // Construct a new tuple and rewire it.
86         OpBuilder builder(func.body());
87         builder.setInsertionPointToStart(&func.body().front());
88         auto new_tuple = builder.create<mhlo::TupleOp>(
89             func.body().getLoc(), tuple_type, flattened_operands);
90         func.getArgument(original_argument_index).replaceAllUsesWith(new_tuple);
91 
92         // Now the original argument has been rewired, we should be able to
93         // safely erase it.
94         func.eraseArgument(original_argument_index);
95       }
96     }
97 
98     // Update output signatures.
99     auto return_op = cast<mlir::ReturnOp>(func.getBody().back().back());
100 
101     // Expand all tuples in old return operands.
102     SmallVector<Value, 4> expanded_return_operands;
103     SmallVector<Type, 4> expanded_result_types;
104     for (auto value : return_op.getOperands()) {
105       auto tuple = dyn_cast_or_null<mhlo::TupleOp>(value.getDefiningOp());
106       if (!tuple) {
107         expanded_return_operands.push_back(value);
108         expanded_result_types.push_back(value.getType());
109         continue;
110       }
111 
112       for (auto tuple_operand : tuple.getOperands()) {
113         expanded_return_operands.push_back(tuple_operand);
114         expanded_result_types.push_back(tuple_operand.getType());
115       }
116     }
117 
118     if (expanded_return_operands.empty()) return;
119 
120     OpBuilder builder(return_op);
121     builder.create<mlir::ReturnOp>(return_op.getLoc(),
122                                    expanded_return_operands);
123     return_op.erase();
124     auto new_func_type =
125         FunctionType::get(old_func_type.getContext(), expanded_input_types,
126                           expanded_result_types);
127     func.setType(new_func_type);
128   }
129 
runOnOperation()130   void runOnOperation() override {
131     auto module = getOperation();
132     // Find `main` function.
133     auto entry_function = module.lookupSymbol<FuncOp>(entry_function_name_);
134     if (!entry_function) {
135       return;
136     }
137 
138     ExpandTupledTensorInReturnOp(entry_function);
139   }
140 };
141 
142 }  // end namespace
143 
CreateExpandHloTuplesPass(const std::string & entry_function_name)144 std::unique_ptr<OperationPass<ModuleOp>> CreateExpandHloTuplesPass(
145     const std::string& entry_function_name) {
146   return std::make_unique<ExpandHloTuplesPass>(entry_function_name);
147 }
148 
149 }  // namespace mhlo
150 }  // namespace mlir
151