1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements the logic to lower some specific ops to external
// library calls.
//
// Here each external function call is modeled by a `disc_ral.dispatch` op. We
// use `disc_ral.dispatch` as a unified entry point for DISC external calls for
// the following reasons:
// - `disc_ral.dispatch` ensures that the first argument is always the
//   `disc_ral.context`.
// - `disc_ral.dispatch` simplifies the logic to handle different instantiations
//   of one op for different devices and different element types. For example,
//   we may have GEMM ops with different element types.
27 
28 #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"
31 #include "mlir/IR/Attributes.h"
32 #include "mlir/IR/Builders.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/Location.h"
36 #include "mlir/IR/MLIRContext.h"
37 #include "mlir/IR/Operation.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
40 
41 namespace mlir {
42 namespace disc_ral {
43 
44 namespace {
45 
46 // Converting:
47 //   %output = disc_ral.recv_input(ctx, input_idx)
48 //     to
49 //   %output = disc_ral.dispatch(ctx, input_idx) {call_target_name =
50 //   "ral_recv_input", backend_config = "cpu"}
51 struct RecvInputOpConvertor : public OpRewritePattern<RecvInputOp> {
52   using OpRewritePattern<RecvInputOp>::OpRewritePattern;
53 
matchAndRewritemlir::disc_ral::__anonf4e11f340111::RecvInputOpConvertor54   LogicalResult matchAndRewrite(RecvInputOp op,
55                                 PatternRewriter& rewriter) const override {
56     auto operands = op.getOperands();
57     rewriter.replaceOpWithNewOp<DispatchOp>(op, op.getType(), operands.front(),
58                                             operands.drop_front(),
59                                             "ral_recv_input", false, "cpu");
60     return success();
61   }
62 };
63 
64 // Converting:
65 //   disc_ral.send_output(ctx, output_idx, output)
66 //     to
67 //   disc_ral.dispatch(ctx, output_idx, output) {call_target_name =
68 //   "ral_send_output", backend_config = "cpu"}
69 struct SendOutputOpConvertor : public OpRewritePattern<SendOutputOp> {
70   using OpRewritePattern<SendOutputOp>::OpRewritePattern;
71 
matchAndRewritemlir::disc_ral::__anonf4e11f340111::SendOutputOpConvertor72   LogicalResult matchAndRewrite(SendOutputOp op,
73                                 PatternRewriter& rewriter) const override {
74     auto operands = op.getOperands();
75     rewriter.replaceOpWithNewOp<DispatchOp>(op, llvm::None, operands.front(),
76                                             operands.drop_front(),
77                                             "ral_send_output", false, "cpu");
78     return success();
79   }
80 };
81 
82 struct RalLowerToLibraryCallPass
83     : public RalLowerToLibraryCallPassBase<RalLowerToLibraryCallPass> {
84   using RalLowerToLibraryCallPassBase<
85       RalLowerToLibraryCallPass>::RalLowerToLibraryCallPassBase;
86 
runOnFunctionmlir::disc_ral::__anonf4e11f340111::RalLowerToLibraryCallPass87   void runOnFunction() override {
88     FuncOp func = getFunction();
89     MLIRContext* context = &getContext();
90     OwningRewritePatternList patterns(context);
91     patterns.insert<RecvInputOpConvertor, SendOutputOpConvertor>(context);
92     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
93       func.emitError("applyPatternsAndFoldGreedily does not converge");
94       signalPassFailure();
95     }
96   }
97 };
98 
99 }  // namespace
100 
createRalLowerToLibraryCallPass()101 std::unique_ptr<mlir::FunctionPass> createRalLowerToLibraryCallPass() {
102   return std::make_unique<RalLowerToLibraryCallPass>();
103 }
104 
105 }  // namespace disc_ral
106 }  // namespace mlir
107