1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements logic for injecting the execution context into the
// entry function.
//
// Below is an example. Before conversion:
// ```
//   func @main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>)
//       -> memref<?x?xf32> {
//     %0 = memref.alloc(...)
//     "lmhlo.add"(%arg0, %arg1, %0) : (memref<?x?xf32>, memref<?x?xf32>,
//                                      memref<?x?xf32>) -> ()
//     return %0 : memref<?x?xf32>
//   }
// ```
// After conversion (SSA names of the received inputs are kept from the
// original function for readability):
// ```
//   func @main(%ctx: !disc_ral.context) {
//     %c0 = constant 0 : index
//     %arg0 = "disc_ral.recv_input"(%ctx, %c0) : (!disc_ral.context, index)
//         -> memref<?x?xf32>
//     %c1 = constant 1 : index
//     %arg1 = "disc_ral.recv_input"(%ctx, %c1) : (!disc_ral.context, index)
//         -> memref<?x?xf32>
//     %0 = memref.alloc(...)
//     "lmhlo.add"(%arg0, %arg1, %0) : (memref<?x?xf32>, memref<?x?xf32>,
//                                      memref<?x?xf32>) -> ()
//     %c0_0 = constant 0 : index
//     "disc_ral.send_output"(%ctx, %c0_0, %0)
//         : (!disc_ral.context, index, memref<?x?xf32>) -> ()
//     return
//   }
// ```
41
// Notes:
// 1. Only the entry function is rewritten (it is assumed that no other
//    function calls the entry function directly):
//    - its signature is rewritten;
//    - its return-like ops are rewritten.
// 2. Currently we assume that all functions other than the entry function
//    have been inlined into the entry function, so call ops and other
//    functions are not rewritten at the moment. Revisit this assumption if
//    necessary.
49
50 #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
51 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
52 #include "mlir/Dialect/StandardOps/IR/Ops.h"
53 #include "mlir/IR/Attributes.h"
54 #include "mlir/IR/Builders.h"
55 #include "mlir/IR/BuiltinOps.h"
56 #include "mlir/IR/BuiltinTypes.h"
57 #include "mlir/IR/Location.h"
58 #include "mlir/IR/MLIRContext.h"
59 #include "mlir/IR/Operation.h"
60 #include "mlir/Pass/Pass.h"
61
62 namespace mlir {
63 namespace disc_ral {
64
65 namespace {
66
67 struct RalInjectExecutionContextPass
68 : public RalInjectExecutionContextPassBase<RalInjectExecutionContextPass> {
RalInjectExecutionContextPassmlir::disc_ral::__anon3f7c408e0111::RalInjectExecutionContextPass69 explicit RalInjectExecutionContextPass(const std::string& entry_func_name)
70 : RalInjectExecutionContextPassBase<RalInjectExecutionContextPass>::
71 RalInjectExecutionContextPassBase() {
72 this->entry_func_name_ = entry_func_name;
73 }
74
getDependentDialectsmlir::disc_ral::__anon3f7c408e0111::RalInjectExecutionContextPass75 void getDependentDialects(DialectRegistry& registry) const override {
76 registry.insert<RalDialect>();
77 }
78
runOnOperationmlir::disc_ral::__anon3f7c408e0111::RalInjectExecutionContextPass79 void runOnOperation() override {
80 ModuleOp m = getOperation();
81 FuncOp main = m.lookupSymbol<FuncOp>(entry_func_name_);
82 if (!main) {
83 m.emitError("entry func: " + entry_func_name_ + " not found");
84 signalPassFailure();
85 }
86
87 Location loc = main.getLoc();
88 FunctionType funcType = main.getType();
89 OpBuilder b(&main.getBody());
90 Block* entry_block = &main.getBody().front();
91 Type ctx_type = RalExecutionContextType::get(b.getContext());
92
93 // 1. Prepend context to the entry block arguments
94 Value ctx = entry_block->insertArgument(0u, ctx_type);
95
96 // 2. remap original arguments to recv_input ops
97 for (auto&& en : llvm::enumerate(
98 llvm::zip(funcType.getInputs(),
99 entry_block->getArguments().drop_front(1)))) {
100 Value idx = b.create<ConstantIndexOp>(loc, en.index());
101 Type argType = std::get<0>(en.value());
102 Value oldArgument = std::get<1>(en.value());
103 Value newInput = b.create<RecvInputOp>(loc, argType, ctx, idx);
104 oldArgument.replaceAllUsesWith(newInput);
105 }
106
107 // 3. remap all return-like ops to send_output ops
108 for (auto& block : main.getBody()) {
109 if (block.empty()) continue;
110 Operation& operation = block.back();
111 if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
112 b.setInsertionPoint(&operation);
113 for (auto& en : llvm::enumerate(operation.getOperands())) {
114 Value idx = b.create<ConstantIndexOp>(loc, en.index());
115 b.create<SendOutputOp>(loc, ctx, idx, en.value());
116 }
117 operation.eraseOperands(0, operation.getNumOperands());
118 }
119
120 // 4. remove unused block arguments of entry block
121 for (int i = 0, e = funcType.getInputs().size(); i < e; ++i) {
122 // continue to remove the 1st (starting from zero) argument
123 entry_block->eraseArgument(1);
124 }
125
126 // 5. set entry func to new type
127 main.setType(b.getFunctionType({ctx_type}, {}));
128 }
129 };
130
131 } // namespace
132
createRalInjectExecutionContextPass(const std::string & entry_func_name)133 std::unique_ptr<OperationPass<ModuleOp>> createRalInjectExecutionContextPass(
134 const std::string& entry_func_name) {
135 return std::make_unique<RalInjectExecutionContextPass>(entry_func_name);
136 }
137
138 } // namespace disc_ral
139 } // namespace mlir
140