• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file contains the analysis and transformation to rewrite kernel
17 // functions such that information about alignment, aliasing and zero offsets
18 // stemming from the tf_framework uses is propagated.
19 
20 #include <cstdint>
21 #include <memory>
22 
23 #include "llvm/ADT/Bitfields.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
28 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
31 #include "mlir/Support/LLVM.h"  // from @llvm-project
32 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
33 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
34 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
35 
36 namespace mlir {
37 namespace kernel_gen {
38 namespace transforms {
39 namespace {
40 
41 #define GEN_PASS_CLASSES
42 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
43 
44 struct PropagateTfAbiKnowledgeToKernelsPass
45     : public PropagateTfAbiKnowledgeToKernelsBase<
46           PropagateTfAbiKnowledgeToKernelsPass> {
runOnOperationmlir::kernel_gen::transforms::__anon402b45730111::PropagateTfAbiKnowledgeToKernelsPass47   void runOnOperation() override {
48     func::FuncOp function = getOperation();
49     llvm::SmallVector<Value, 4> worklist;
50     // We currently only handle entry functions and do not propagate across
51     // functions.
52     if (function->getAttrOfType<mlir::UnitAttr>(
53             tf_framework::TFFrameworkDialect::kTFEntryAttrName)) {
54       // For all operands of this function, we know they are aligned. Also, by
55       // construction of kernel generator, we know that there is no offset and
56       // the inner stride is one.
57       // TODO(herhut): Insert asserts in debug mode to check this.
58       for (auto argument : function.getArguments()) {
59         if (argument.getType().isa<BaseMemRefType>()) {
60           worklist.push_back(argument);
61           allocated_by_tf_runtime.insert(argument);
62           offset_is_zero.insert(argument);
63           inner_stride_is_constant.insert({argument, 1});
64         }
65       }
66     }
67 
68     // For locally allocated values, we know they are aligned and have offset
69     // zero. Further, they also do not alias with other memrefs, except in
70     // benign ways. This is by construction and ensured by the reuse analysis.
71     function.walk([&](tf_framework::TFAllocOp op) {
72       Value allocated = op.getResult();
73       worklist.push_back(allocated);
74       no_alias.insert(allocated);
75       allocated_by_tf_runtime.insert(allocated);
76       offset_is_zero.insert(allocated);
77       inner_stride_is_constant.insert({allocated, 1});
78     });
79 
80     // Next, take what we have and propagate it through known operations.
81     propagateThroughUses(worklist);
82 
83     // Now look at launches and make use of the knowledge we have.
84     function.walk([&](gpu::LaunchFuncOp launch) {
85       auto module = launch->getParentOfType<ModuleOp>();
86       auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
87 
88       if (!kernel || kernel.isExternal()) return;
89 
90       // Count the position of kernel operands independently, as they do not
91       // coincide with laucnh operands as memref parameters get expanded when
92       // lowered to llvm.
93       int kernel_p = 0;
94       OpBuilder b = OpBuilder::atBlockBegin(&kernel.getBody().front());
95       llvm::SmallDenseMap<int64_t, Value> constants;
96       auto loc = kernel.getLoc();
97       for (auto operand : launch.operands()) {
98         auto memref = operand.getType().dyn_cast<MemRefType>();
99         if (!memref) {
100           // Scalar argument, advance kernel position by one.
101           kernel_p++;
102           continue;
103         }
104         if (allocated_by_tf_runtime.contains(operand)) {
105           // This was allocated by the tf runtime, so the two pointers in the
106           // descriptor coincide. Rewrite the kernel accordingly.
107           Value alloc_ptr = kernel.getArgument(kernel_p);
108           Value align_ptr = kernel.getArgument(kernel_p + 1);
109           alloc_ptr.replaceAllUsesWith(align_ptr);
110           kernel.setArgAttr(
111               kernel_p + 1, LLVM::LLVMDialect::getAlignAttrName(),
112               b.getIndexAttr(
113                   tf_framework::TFFrameworkDialect::kAllocationAlignment));
114         }
115         if (offset_is_zero.contains(operand)) {
116           Value offset = kernel.getArgument(kernel_p + 2);
117           Value &zero = constants[0];
118           if (!zero) {
119             zero = b.create<LLVM::ConstantOp>(loc, offset.getType(),
120                                               b.getIndexAttr(0));
121           }
122           offset.replaceAllUsesWith(zero);
123         }
124         auto const_stride = inner_stride_is_constant.find(operand);
125         if (const_stride != inner_stride_is_constant.end()) {
126           // The stride is the last argument belonging to this memref.
127           Value inner_stride =
128               kernel.getArgument(kernel_p + 2 + memref.getRank() * 2);
129           Value &stride_val = constants[const_stride->second];
130           if (!stride_val) {
131             stride_val = b.create<LLVM::ConstantOp>(
132                 loc, inner_stride.getType(),
133                 b.getIndexAttr(const_stride->second));
134           }
135           inner_stride.replaceAllUsesWith(stride_val);
136         }
137         if (no_alias.contains(operand)) {
138           // TODO(herhut): We also need to check whether any of the other args
139           //     are aliases. This is currently never the case by construction
140           //     but we could use the alias analysis from buffer placement here
141           //     to make sure.
142           // Add the no_alias attribute to the corresponding pointer.
143           kernel.setArgAttr(kernel_p + 1,
144                             LLVM::LLVMDialect::getNoAliasAttrName(),
145                             b.getUnitAttr());
146         }
147         // Advance base, aligned, offset, strides and sizes many arguments.
148         kernel_p += memref.getRank() * 2 + 3;
149       }
150     });
151   }
152 
153  private:
propagateThroughUsesmlir::kernel_gen::transforms::__anon402b45730111::PropagateTfAbiKnowledgeToKernelsPass154   void propagateThroughUses(SmallVectorImpl<Value> &worklist) {
155     while (!worklist.empty()) {
156       Value candidate = worklist.pop_back_val();
157       for (auto user : candidate.getUsers()) {
158         if (isa<memref::CastOp, memref::ReshapeOp>(user)) {
159           // Reshape and Cast propagate alignment, offset and innermost stride.
160           // TODO(herhut): This should be a trait.
161           Value result = user->getResult(0);
162           if (allocated_by_tf_runtime.contains(candidate)) {
163             allocated_by_tf_runtime.insert(result);
164           }
165           auto const_stride = inner_stride_is_constant.find(candidate);
166           if (const_stride != inner_stride_is_constant.end()) {
167             inner_stride_is_constant.insert({result, const_stride->second});
168           }
169           if (offset_is_zero.contains(candidate)) {
170             offset_is_zero.insert(result);
171           }
172           worklist.push_back(result);
173         }
174         if (auto cast = dyn_cast<memref::ReinterpretCastOp>(user)) {
175           // Check that we have offset 0.
176           Value result = cast.getResult();
177           if (!cast.isDynamicOffset(0) && cast.getStaticOffset(0) == 0) {
178             offset_is_zero.insert(result);
179           }
180           if (allocated_by_tf_runtime.contains(candidate)) {
181             allocated_by_tf_runtime.insert(result);
182           }
183           size_t last_stride = cast.getResultRank() - 1;
184           // TODO(herhut): Remove this once canonicalization handles this.
185           if (cast.isDynamicStride(last_stride)) {
186             auto dyn_stride = cast.getDynamicStride(last_stride)
187                                   .getDefiningOp<arith::ConstantIndexOp>();
188             if (dyn_stride) {
189               inner_stride_is_constant.insert({result, dyn_stride.value()});
190             }
191           } else {
192             inner_stride_is_constant.insert(
193                 {result, cast.getStaticStride(last_stride)});
194           }
195           worklist.push_back(result);
196         }
197       }
198     }
199   }
200 
201   // Set of values that were allocated by the tf runtime and hence are aligned.
202   llvm::SmallPtrSet<Value, 8> allocated_by_tf_runtime;
203   // Set of values that are known to not have an offset of 0.
204   llvm::SmallPtrSet<Value, 8> offset_is_zero;
205   // Set of values that are known to have a constant stride.
206   llvm::SmallDenseMap<Value, int64_t, 8> inner_stride_is_constant;
207   // Set of values we know do not alias other values.
208   llvm::SmallPtrSet<Value, 8> no_alias;
209 };
210 
211 }  // namespace
212 
213 std::unique_ptr<OperationPass<func::FuncOp>>
CreatePropagateTfAbiKnowledgeToKernels()214 CreatePropagateTfAbiKnowledgeToKernels() {
215   return std::make_unique<PropagateTfAbiKnowledgeToKernelsPass>();
216 }
217 
218 }  // namespace transforms
219 }  // namespace kernel_gen
220 }  // namespace mlir
221