• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements logic for injecting an execution context into the entry
// function.
//
// Below is an example. Before conversion:
//  ```
//   func @main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>)
//       -> memref<?x?xf32> {
//     %0 = memref.alloc(...)
//     "lmhlo.add"(%arg0, %arg1, %0) :
//         (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
//     return %0 : memref<?x?xf32>
//   }
//  ```
// After conversion:
//  ```
//   func @main(%ctx: !disc_ral.context) {
//     %c0 = constant 0 : index
//     %0 = "disc_ral.recv_input"(%ctx, %c0) :
//         (!disc_ral.context, index) -> memref<?x?xf32>
//     %c1 = constant 1 : index
//     %1 = "disc_ral.recv_input"(%ctx, %c1) :
//         (!disc_ral.context, index) -> memref<?x?xf32>
//     %2 = memref.alloc(...)
//     "lmhlo.add"(%0, %1, %2) :
//         (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
//     %c0_0 = constant 0 : index
//     "disc_ral.send_output"(%ctx, %c0_0, %2) :
//         (!disc_ral.context, index, memref<?x?xf32>) -> ()
//     return
//   }
//  ```
41 
// 1. Rewrite the entry function (assuming that no other function directly
//    calls it):
//    - rewrite the function signature;
//    - rewrite return-like ops.
// 2. We currently assume that all functions other than the entry function have
//    been inlined into the entry function, so call ops and other functions are
//    not rewritten at the moment. Revisit this assumption if necessary.
49 
50 #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
51 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
52 #include "mlir/Dialect/StandardOps/IR/Ops.h"
53 #include "mlir/IR/Attributes.h"
54 #include "mlir/IR/Builders.h"
55 #include "mlir/IR/BuiltinOps.h"
56 #include "mlir/IR/BuiltinTypes.h"
57 #include "mlir/IR/Location.h"
58 #include "mlir/IR/MLIRContext.h"
59 #include "mlir/IR/Operation.h"
60 #include "mlir/Pass/Pass.h"
61 
62 namespace mlir {
63 namespace disc_ral {
64 
65 namespace {
66 
67 struct RalInjectExecutionContextPass
68     : public RalInjectExecutionContextPassBase<RalInjectExecutionContextPass> {
RalInjectExecutionContextPassmlir::disc_ral::__anon3f7c408e0111::RalInjectExecutionContextPass69   explicit RalInjectExecutionContextPass(const std::string& entry_func_name)
70       : RalInjectExecutionContextPassBase<RalInjectExecutionContextPass>::
71             RalInjectExecutionContextPassBase() {
72     this->entry_func_name_ = entry_func_name;
73   }
74 
getDependentDialectsmlir::disc_ral::__anon3f7c408e0111::RalInjectExecutionContextPass75   void getDependentDialects(DialectRegistry& registry) const override {
76     registry.insert<RalDialect>();
77   }
78 
runOnOperationmlir::disc_ral::__anon3f7c408e0111::RalInjectExecutionContextPass79   void runOnOperation() override {
80     ModuleOp m = getOperation();
81     FuncOp main = m.lookupSymbol<FuncOp>(entry_func_name_);
82     if (!main) {
83       m.emitError("entry func: " + entry_func_name_ + " not found");
84       signalPassFailure();
85     }
86 
87     Location loc = main.getLoc();
88     FunctionType funcType = main.getType();
89     OpBuilder b(&main.getBody());
90     Block* entry_block = &main.getBody().front();
91     Type ctx_type = RalExecutionContextType::get(b.getContext());
92 
93     // 1. Prepend context to the entry block arguments
94     Value ctx = entry_block->insertArgument(0u, ctx_type);
95 
96     // 2. remap original arguments to recv_input ops
97     for (auto&& en : llvm::enumerate(
98              llvm::zip(funcType.getInputs(),
99                        entry_block->getArguments().drop_front(1)))) {
100       Value idx = b.create<ConstantIndexOp>(loc, en.index());
101       Type argType = std::get<0>(en.value());
102       Value oldArgument = std::get<1>(en.value());
103       Value newInput = b.create<RecvInputOp>(loc, argType, ctx, idx);
104       oldArgument.replaceAllUsesWith(newInput);
105     }
106 
107     // 3. remap all return-like ops to send_output ops
108     for (auto& block : main.getBody()) {
109       if (block.empty()) continue;
110       Operation& operation = block.back();
111       if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
112       b.setInsertionPoint(&operation);
113       for (auto& en : llvm::enumerate(operation.getOperands())) {
114         Value idx = b.create<ConstantIndexOp>(loc, en.index());
115         b.create<SendOutputOp>(loc, ctx, idx, en.value());
116       }
117       operation.eraseOperands(0, operation.getNumOperands());
118     }
119 
120     // 4. remove unused block arguments of entry block
121     for (int i = 0, e = funcType.getInputs().size(); i < e; ++i) {
122       // continue to remove the 1st (starting from zero) argument
123       entry_block->eraseArgument(1);
124     }
125 
126     // 5. set entry func to new type
127     main.setType(b.getFunctionType({ctx_type}, {}));
128   }
129 };
130 
131 }  // namespace
132 
createRalInjectExecutionContextPass(const std::string & entry_func_name)133 std::unique_ptr<OperationPass<ModuleOp>> createRalInjectExecutionContextPass(
134     const std::string& entry_func_name) {
135   return std::make_unique<RalInjectExecutionContextPass>(entry_func_name);
136 }
137 
138 }  // namespace disc_ral
139 }  // namespace mlir
140