1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"  // from @llvm-project
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
24 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
25 #include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
26 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h"
27 
28 namespace xla {
29 namespace runtime {
30 
31 using namespace mlir;  // NOLINT
32 
33 #define GEN_PASS_CLASSES
34 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h.inc"
35 
36 class ConvertToEntrypointPass
37     : public ConvertToEntrypointBase<ConvertToEntrypointPass> {
38   void runOnOperation() override;
39 };
40 
ConvertCustomCallOperations(func::FuncOp func,Value exec_ctx)41 static void ConvertCustomCallOperations(func::FuncOp func, Value exec_ctx) {
42   MLIRContext* ctx = func->getContext();
43 
44   SymbolTable sym_table(func->getParentOfType<ModuleOp>());
45 
46   struct CustomCall {
47     func::CallOp call;
48     func::FuncOp callee;
49     llvm::StringRef target;
50     bool direct;
51   };
52 
53   // Collect function calls that have to be converted to custom calls.
54   llvm::SmallVector<CustomCall> custom_calls;
55   func.walk([&](func::CallOp op) {
56     auto callee = dyn_cast<func::FuncOp>(sym_table.lookup(op.getCallee()));
57     if (!callee) return;
58 
59     // Check if the call is an indirect custom call ...
60     StringAttr target = callee->getAttrOfType<StringAttr>("rt.custom_call");
61     if (target) custom_calls.push_back({op, callee, target.strref(), false});
62 
63     // ... or a direct custom call.
64     target = callee->getAttrOfType<StringAttr>("rt.direct_custom_call");
65     if (target) custom_calls.push_back({op, callee, target.strref(), true});
66   });
67 
68   // After converting to custom call we need to clean up all declarations.
69   llvm::DenseSet<func::FuncOp> erase_declarations;
70 
71   // Rewrite function calls to `rt.custom_call` operations.
72   for (CustomCall custom_call : custom_calls) {
73     ImplicitLocOpBuilder b(custom_call.call.getLoc(), custom_call.call);
74 
75     // Custom call intrinsic always returns the status flag.
76     llvm::SmallVector<Type> results = {StatusType::get(ctx)};
77     results.append(custom_call.call->getResultTypes().begin(),
78                    custom_call.call->getResultTypes().end());
79 
80     // Rewrite function call with a custom call, and check the return status.
81     auto call = b.create<CustomCallOp>(results, exec_ctx, custom_call.target,
82                                        custom_call.direct,
83                                        custom_call.call.getOperands());
84 
85     // Copy optional attributes from the custom call function declaration.
86     llvm::ArrayRef<llvm::StringRef> callee_attrs =
87         custom_call.callee.getAttributeNames();
88     for (auto& attr : custom_call.callee->getAttrs()) {
89       if (isa_and_nonnull<RuntimeDialect>(attr.getNameDialect())) continue;
90       if (llvm::find(callee_attrs, attr.getName()) == callee_attrs.end())
91         call->setAttr(attr.getName(), attr.getValue());
92     }
93 
94     // Copy optional attributes from the call operation to the custom call.
95     llvm::ArrayRef<llvm::StringRef> orig_attrs =
96         custom_call.call.getAttributeNames();
97     for (auto& attr : custom_call.call->getAttrs()) {
98       if (llvm::find(orig_attrs, attr.getName()) == orig_attrs.end())
99         call->setAttr(attr.getName(), attr.getValue());
100     }
101 
102     b.create<cf::AssertOp>(
103         b.create<IsOkOp>(TypeRange(b.getI1Type()), call.status()),
104         b.getStringAttr("custom call '" + custom_call.target + "' failed"));
105 
106     // Forward users of the original results to custom call results.
107     auto rets = llvm::zip(custom_call.call.getResults(),
108                           llvm::drop_begin(call.getResults()));
109     llvm::for_each(rets, [](auto ret) {
110       std::get<0>(ret).replaceAllUsesWith(std::get<1>(ret));
111     });
112 
113     // Keep track of custom call declaration to erase.
114     erase_declarations.insert(custom_call.callee);
115 
116     // Erase the original function call operation.
117     custom_call.call.erase();
118   }
119 
120   // Erase all converted custom calls declarations.
121   for (auto func : erase_declarations) sym_table.erase(func);
122 }
123 
ConvertReturnOperations(func::FuncOp func,Value exec_ctx)124 static void ConvertReturnOperations(func::FuncOp func, Value exec_ctx) {
125   // Convert all returns to the Runtime API calls.
126   func.walk([&](func::ReturnOp ret) {
127     ImplicitLocOpBuilder b(ret.getLoc(), ret);
128 
129     // Return all outputs via the `rt.set_output` operation.
130     for (auto& pair : llvm::enumerate(ret.getOperands())) {
131       b.create<SetOutputOp>(exec_ctx, pair.index(), pair.value());
132     }
133 
134     // Replace original return with an empty one.
135     b.create<func::ReturnOp>();
136     ret.erase();
137   });
138 
139   // Update function type to the function with empty results.
140   auto type = FunctionType::get(func.getContext(), func.getArgumentTypes(), {});
141   func.setType(type);
142 }
143 
ConvertAssertOperations(func::FuncOp func,Value exec_ctx)144 static void ConvertAssertOperations(func::FuncOp func, Value exec_ctx) {
145   // Collect all assert operations in the function body.
146   llvm::SmallVector<cf::AssertOp> asserts;
147   func.walk([&](cf::AssertOp op) { asserts.push_back(op); });
148 
149   // Rewrite all asserts to the Runtime API calls.
150   for (cf::AssertOp assert : asserts) {
151     ImplicitLocOpBuilder b(assert.getLoc(), assert);
152 
153     // Split the block at the assert operation.
154     Block* block = assert->getBlock();
155     Block* ok = block->splitBlock(assert);
156 
157     // Set up block for returning error.
158     Block* err = func.addBlock();
159     b.setInsertionPointToStart(err);
160     b.create<SetErrorOp>(exec_ctx, assert.getMsg());
161     b.create<func::ReturnOp>();
162 
163     // Branch into the error block if assertion failed.
164     b.setInsertionPointToEnd(block);
165     b.create<cf::CondBranchOp>(assert.getArg(), ok, err);
166 
167     // Erase the original assert operation.
168     assert.erase();
169   }
170 }
171 
PrependExecutionContextArgument(func::FuncOp func)172 static Value PrependExecutionContextArgument(func::FuncOp func) {
173   Type new_type = KernelContextType::get(func.getContext());
174   DictionaryAttr attr = DictionaryAttr::get(func.getContext());
175   func.insertArguments({0}, {new_type}, {attr}, {func.getLoc()});
176   return func.getArgument(0);
177 }
178 
ConvertToEntrypoint(func::FuncOp func)179 static void ConvertToEntrypoint(func::FuncOp func) {
180   assert(func->hasAttr(kEntrypointAttrName));
181 
182   Value exec_ctx = PrependExecutionContextArgument(func);
183   ConvertCustomCallOperations(func, exec_ctx);
184   ConvertReturnOperations(func, exec_ctx);
185   ConvertAssertOperations(func, exec_ctx);
186 
187   // After conversion !rt.execution_context is a marker of an entrypoint.
188   func->removeAttr(kEntrypointAttrName);
189 }
190 
runOnOperation()191 void ConvertToEntrypointPass::runOnOperation() {
192   llvm::SmallVector<func::FuncOp> entry_points;
193 
194   // Collect entrypoint functions.
195   getOperation().walk([&](func::FuncOp op) {
196     if (op->hasAttr(kEntrypointAttrName)) entry_points.push_back(op);
197   });
198 
199   llvm::for_each(entry_points, ConvertToEntrypoint);
200 }
201 
CreateConvertToEntrypoint()202 std::unique_ptr<OperationPass<ModuleOp>> CreateConvertToEntrypoint() {
203   return std::make_unique<ConvertToEntrypointPass>();
204 }
205 
206 }  // namespace runtime
207 }  // namespace xla
208