/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the analysis and transformation to rewrite kernel
// functions such that information about alignment, aliasing and zero offsets
// stemming from tf_framework uses is propagated.
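//
// As a rough sketch (types and argument names below are illustrative, not
// taken from actual generated IR), a kernel taking a rank-1 memref that is
// known to be allocated by the TF runtime arrives in a form like
//
//   llvm.func @kernel(%alloc: !llvm.ptr<f32>, %aligned: !llvm.ptr<f32>,
//                     %offset: i64, %size: i64, %stride: i64)
//
// and this pass then replaces uses of %alloc with %aligned, attaches the TF
// framework allocation alignment to %aligned as an argument attribute, and
// substitutes constants for %offset (zero) and %stride (one).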

#include <cstdint>
#include <memory>

#include "llvm/ADT/Bitfields.h"
#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

struct PropagateTfAbiKnowledgeToKernelsPass
    : public PropagateTfAbiKnowledgeToKernelsBase<
          PropagateTfAbiKnowledgeToKernelsPass> {
  void runOnFunction() override {
    FuncOp function = getFunction();
    llvm::SmallVector<Value, 4> worklist;
    // We currently only handle entry functions and do not propagate across
    // functions.
    if (function->getAttrOfType<mlir::UnitAttr>(
            tf_framework::TFFrameworkDialect::kTFEntryAttrName)) {
      // For all operands of this function, we know they are aligned. Also, by
      // construction of kernel generator, we know that there is no offset and
      // the inner stride is one.
      // TODO(herhut): Insert asserts in debug mode to check this.
      for (auto argument : function.getArguments()) {
        if (argument.getType().isa<BaseMemRefType>()) {
          worklist.push_back(argument);
          allocated_by_tf_runtime.insert(argument);
          offset_is_zero.insert(argument);
          inner_stride_is_constant.insert({argument, 1});
        }
      }
    }

    // For locally allocated values, we know they are aligned and have offset
    // zero. Further, they also do not alias with other memrefs, except in
    // benign ways. This is by construction and ensured by the reuse analysis.
    function.walk([&](tf_framework::TFAllocOp op) {
      Value allocated = op.getResult();
      worklist.push_back(allocated);
      no_alias.insert(allocated);
      allocated_by_tf_runtime.insert(allocated);
      offset_is_zero.insert(allocated);
      inner_stride_is_constant.insert({allocated, 1});
    });

    // Next, take what we have and propagate it through known operations.
    propagateThroughUses(worklist);

    // Now look at launches and make use of the knowledge we have.
    function.walk([&](gpu::LaunchFuncOp launch) {
      auto module = launch->getParentOfType<ModuleOp>();
      auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());

      if (!kernel || kernel.isExternal()) return;

      // Track the position of kernel operands separately, as they do not
      // coincide with launch operands: memref parameters get expanded into
      // multiple arguments when lowered to the LLVM dialect.
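      // A ranked memref operand expands into 3 + 2 * rank kernel arguments:
      // the allocated pointer, the aligned pointer, the offset, then rank
      // sizes followed by rank strides. The index arithmetic below relies on
      // this layout.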
      int kernel_p = 0;
      OpBuilder b = OpBuilder::atBlockBegin(&kernel.body().front());
      llvm::SmallDenseMap<int64_t, Value> constants;
      auto loc = kernel.getLoc();
      for (auto operand : launch.operands()) {
        auto memref = operand.getType().dyn_cast<MemRefType>();
        if (!memref) {
          // Scalar argument, advance kernel position by one.
          kernel_p++;
          continue;
        }
        if (allocated_by_tf_runtime.contains(operand)) {
          // This was allocated by the tf runtime, so the two pointers in the
          // descriptor coincide. Rewrite the kernel accordingly.
          Value alloc_ptr = kernel.getArgument(kernel_p);
          Value align_ptr = kernel.getArgument(kernel_p + 1);
          alloc_ptr.replaceAllUsesWith(align_ptr);
          kernel.setArgAttr(
              kernel_p + 1, LLVM::LLVMDialect::getAlignAttrName(),
              b.getIndexAttr(
                  tf_framework::TFFrameworkDialect::kAllocationAlignment));
        }
        if (offset_is_zero.contains(operand)) {
          Value offset = kernel.getArgument(kernel_p + 2);
          Value &zero = constants[0];
          if (!zero) {
            zero = b.create<LLVM::ConstantOp>(loc, offset.getType(),
                                              b.getIndexAttr(0));
          }
          offset.replaceAllUsesWith(zero);
        }
        auto const_stride = inner_stride_is_constant.find(operand);
        if (const_stride != inner_stride_is_constant.end()) {
          // The stride is the last argument belonging to this memref.
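          // With the layout above, the innermost stride is the last of the
          // 2 * rank size/stride arguments, i.e. the argument at index
          // kernel_p + 2 + 2 * rank.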
          Value inner_stride =
              kernel.getArgument(kernel_p + 2 + memref.getRank() * 2);
          Value &stride_val = constants[const_stride->second];
          if (!stride_val) {
            stride_val = b.create<LLVM::ConstantOp>(
                loc, inner_stride.getType(),
                b.getIndexAttr(const_stride->second));
          }
          inner_stride.replaceAllUsesWith(stride_val);
        }
        if (no_alias.contains(operand)) {
          // TODO(herhut): We also need to check whether any of the other args
          // are aliases. This is currently never the case by construction
          // but we could use the alias analysis from buffer placement here
          // to make sure.
          // Add the no_alias attribute to the corresponding pointer.
          kernel.setArgAttr(kernel_p + 1,
                            LLVM::LLVMDialect::getNoAliasAttrName(),
                            b.getBoolAttr(true));
        }
        // Advance past this memref's arguments: the two pointers, the offset,
        // and rank-many sizes and strides.
        kernel_p += memref.getRank() * 2 + 3;
      }
    });
  }

 private:
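  // Propagates the recorded facts from each value in the worklist to values
  // derived from it via memref cast, reshape, and reinterpret_cast, so the
  // knowledge about function arguments and TF runtime allocations also covers
  // views created from them.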
  void propagateThroughUses(SmallVectorImpl<Value> &worklist) {
    while (!worklist.empty()) {
      Value candidate = worklist.pop_back_val();
      for (auto user : candidate.getUsers()) {
        if (isa<MemRefCastOp, MemRefReshapeOp>(user)) {
          // Reshape and Cast propagate alignment, offset and innermost stride.
          // TODO(herhut): This should be a trait.
          Value result = user->getResult(0);
          if (allocated_by_tf_runtime.contains(candidate)) {
            allocated_by_tf_runtime.insert(result);
          }
          auto const_stride = inner_stride_is_constant.find(candidate);
          if (const_stride != inner_stride_is_constant.end()) {
            inner_stride_is_constant.insert({result, const_stride->second});
          }
          if (offset_is_zero.contains(candidate)) {
            offset_is_zero.insert(result);
          }
          worklist.push_back(result);
        }
        if (auto cast = dyn_cast<MemRefReinterpretCastOp>(user)) {
          // Check that we have offset 0.
          Value result = cast.result();
          if (!cast.isDynamicOffset(0) && cast.getStaticOffset(0) == 0) {
            offset_is_zero.insert(result);
          }
          if (allocated_by_tf_runtime.contains(candidate)) {
            allocated_by_tf_runtime.insert(result);
          }
          size_t last_stride = cast.getResultRank() - 1;
          // TODO(herhut): Remove this once canonicalization handles this.
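          // A reinterpret_cast may receive its innermost stride as a dynamic
          // operand even when that operand is a constant; until
          // canonicalization folds such strides to static ones, look through
          // the defining constant here.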
          if (cast.isDynamicStride(last_stride)) {
            auto dyn_stride = cast.getDynamicStride(last_stride)
                                  .getDefiningOp<ConstantIndexOp>();
            if (dyn_stride) {
              inner_stride_is_constant.insert({result, dyn_stride.getValue()});
            }
          } else {
            inner_stride_is_constant.insert(
                {result, cast.getStaticStride(last_stride)});
          }
          worklist.push_back(result);
        }
      }
    }
  }

  // Set of values that were allocated by the tf runtime and hence are aligned.
  llvm::SmallPtrSet<Value, 8> allocated_by_tf_runtime;
  // Set of values that are known to have an offset of 0.
  llvm::SmallPtrSet<Value, 8> offset_is_zero;
  // Map from values to their known constant innermost stride.
  llvm::SmallDenseMap<Value, int64_t, 8> inner_stride_is_constant;
  // Set of values we know do not alias other values.
  llvm::SmallPtrSet<Value, 8> no_alias;
};

}  // namespace

std::unique_ptr<FunctionPass> CreatePropagateTfAbiKnowledgeToKernels() {
  return std::make_unique<PropagateTfAbiKnowledgeToKernelsPass>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir