/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
// This file contains the analysis and transformation to rewrite kernel
// functions such that they use a single set of arguments for the strides and
// sizes of operands with equal shapes.
19
20 #include <memory>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMapInfo.h"
24 #include "llvm/ADT/EquivalenceClasses.h"
25 #include "llvm/ADT/Hashing.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
29 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
30 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
31 #include "mlir/IR/AsmState.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/MLIRContext.h" // from @llvm-project
34 #include "mlir/IR/Value.h" // from @llvm-project
35 #include "mlir/Support/LLVM.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
38 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
39
40 #define DEBUG_TYPE "kernel-gen-shapes"
41
42 namespace {
43
44 using mlir::ArrayRef;
45 using mlir::SmallVector;
46 using mlir::Value;
47
48 /// Represents a value or constant. Used to unify operands for operations that
49 /// take both ssa values and attributes.
50 struct ValueOrConst {
ValueOrConst__anonb552731d0111::ValueOrConst51 explicit ValueOrConst(Value v) : value_or_constant(v), is_constant(false) {}
ValueOrConst__anonb552731d0111::ValueOrConst52 explicit ValueOrConst(int64_t c) : value_or_constant(c), is_constant(true) {}
53
value__anonb552731d0111::ValueOrConst54 Value value() const {
55 assert(!is_constant);
56 return value_or_constant.value;
57 }
58
constant__anonb552731d0111::ValueOrConst59 int64_t constant() const {
60 assert(is_constant);
61 return value_or_constant.constant;
62 }
63
isConstant__anonb552731d0111::ValueOrConst64 bool isConstant() const { return is_constant; }
65
66 private:
67 union ValueOrConstStorage {
ValueOrConstStorage(Value v)68 explicit ValueOrConstStorage(Value v) : value(v) {}
ValueOrConstStorage(size_t c)69 explicit ValueOrConstStorage(size_t c) : constant(c) {}
70
71 Value value;
72 int64_t constant;
73 } value_or_constant;
74
75 bool is_constant;
76 };
77
hash_value(ValueOrConst value)78 llvm::hash_code hash_value(ValueOrConst value) {
79 return value.isConstant() ? static_cast<llvm::hash_code>(value.constant())
80 : mlir::hash_value(value.value());
81 }
82
operator ==(ValueOrConst lhs,ValueOrConst rhs)83 bool operator==(ValueOrConst lhs, ValueOrConst rhs) {
84 if (lhs.isConstant()) {
85 return rhs.isConstant() && lhs.constant() == rhs.constant();
86 } else {
87 return !rhs.isConstant() && lhs.value() == rhs.value();
88 }
89 }
90
operator <<(llvm::raw_ostream & os,const ValueOrConst & value)91 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
92 const ValueOrConst &value) {
93 if (value.isConstant()) {
94 os << value.constant();
95 } else {
96 Value val = value.value();
97 mlir::AsmState asm_state(
98 val.getParentRegion()->getParentOfType<mlir::FuncOp>());
99 val.printAsOperand(os, asm_state);
100 }
101 return os;
102 }
103
104 /// Represents a shape, as either a single SSA value that represents the entire
105 /// shape vector or as a vector of SSA values representing scalars.
106 struct ShapeValue {
ShapeValue__anonb552731d0111::ShapeValue107 explicit ShapeValue(Value vector)
108 : shape({ValueOrConst{vector}}), is_vector(true) {}
ShapeValue__anonb552731d0111::ShapeValue109 explicit ShapeValue(ValueOrConst vector) : shape({vector}), is_vector(true) {
110 assert(!vector.isConstant());
111 }
112 template <typename T>
ShapeValue__anonb552731d0111::ShapeValue113 explicit ShapeValue(T values)
114 : shape(values.begin(), values.end()), is_vector(false) {}
115
vector__anonb552731d0111::ShapeValue116 ValueOrConst vector() const {
117 assert(is_vector);
118 return shape.front();
119 }
120
scalars__anonb552731d0111::ShapeValue121 ArrayRef<ValueOrConst> scalars() const {
122 assert(!is_vector);
123 return llvm::makeArrayRef(shape);
124 }
125
isVector__anonb552731d0111::ShapeValue126 bool isVector() const { return is_vector; }
127
128 private:
129 SmallVector<ValueOrConst, 4> shape;
130 bool is_vector;
131 };
132
hash_value(ShapeValue shape)133 llvm::hash_code hash_value(ShapeValue shape) {
134 return shape.isVector() ? hash_value(shape.vector())
135 : hash_value(shape.scalars());
136 }
137
operator ==(ShapeValue lhs,ShapeValue rhs)138 bool operator==(ShapeValue lhs, ShapeValue rhs) {
139 if (lhs.isVector()) {
140 return rhs.isVector() && lhs.vector() == rhs.vector();
141 } else {
142 return !rhs.isVector() && lhs.scalars() == rhs.scalars();
143 }
144 }
145
operator <<(llvm::raw_ostream & os,const ShapeValue & shape)146 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
147 const ShapeValue &shape) {
148 if (shape.isVector()) {
149 os << shape.vector();
150 return os;
151 }
152 os << "[";
153 bool first = true;
154 for (auto scalar : shape.scalars()) {
155 if (!first) {
156 os << ", ";
157 }
158 first = false;
159 os << scalar;
160 }
161 os << "]";
162 return os;
163 }
164
165 } // namespace
166
167 namespace llvm {
168
169 template <>
170 struct DenseMapInfo<ShapeValue> {
getEmptyKeyllvm::DenseMapInfo171 static ShapeValue getEmptyKey() {
172 return ShapeValue(DenseMapInfo<mlir::Value>::getEmptyKey());
173 }
getTombstoneKeyllvm::DenseMapInfo174 static ShapeValue getTombstoneKey() {
175 return ShapeValue(DenseMapInfo<mlir::Value>::getTombstoneKey());
176 }
getHashValuellvm::DenseMapInfo177 static unsigned getHashValue(ShapeValue shape) { return hash_value(shape); }
isEqualllvm::DenseMapInfo178 static bool isEqual(ShapeValue LHS, ShapeValue RHS) { return LHS == RHS; }
179 };
180
181 } // namespace llvm
182
183 namespace mlir {
184 namespace kernel_gen {
185 namespace transforms {
186
187 namespace {
188
189 #define GEN_PASS_CLASSES
190 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
191
192 // A basic shape equality inference. This should be superceeded by a proper
193 // inference once available. Until then, we just build this out to the needs of
194 // the kernel generator project.
195 class ShapeEqualityKnowledge {
196 public:
197 /// Checks all operations for potential shape equality of their respective
198 /// results.
build(FuncOp function)199 void build(FuncOp function) {
200 function.walk([&](Operation *op) {
201 if (auto reshape = dyn_cast<MemRefReshapeOp>(op)) {
202 registerAssociation(ShapeValue{reshape.shape()}, reshape.result());
203 return;
204 }
205 if (auto cast = dyn_cast<MemRefReinterpretCastOp>(op)) {
206 // Only support fully dynamic sizes for now.
207 // TODO(herhut): Fix once the op has canonicalizers that break this.
208 for (unsigned int p = 0, e = cast.getResultRank(); p < e; ++p) {
209 if (!cast.isDynamicSize(p)) {
210 return;
211 }
212 }
213 registerAssociation(ShapeValue{cast.sizes()}, cast.result());
214 return;
215 }
216 if (auto alloc = dyn_cast<AllocOp>(op)) {
217 SmallVector<ValueOrConst, 4> shape;
218 ShapedType type = alloc.getResult().getType().cast<ShapedType>();
219 fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape);
220 registerAssociation(ShapeValue{shape}, alloc.getResult());
221 return;
222 }
223 if (auto alloc = dyn_cast<tf_framework::TFAllocOp>(op)) {
224 // Construct a symbol representing the allocated shape.
225 SmallVector<ValueOrConst, 4> shape;
226 ShapedType type = alloc.getResult().getType().cast<ShapedType>();
227 fillShapeFromAllocLike(alloc.dyn_sizes(), type, shape);
228 registerAssociation(ShapeValue{shape}, alloc.getResult());
229 return;
230 }
231 });
232 }
233
234 /// Checks whether `one` and `other` are known to have the same shape and
235 /// strides.
haveSameShape(Value one,Value other)236 bool haveSameShape(Value one, Value other) {
237 return equal_shapes_.isEquivalent(one.getAsOpaquePointer(),
238 other.getAsOpaquePointer());
239 }
240
241 private:
fillShapeFromAllocLike(mlir::OperandRange operands,ShapedType type,SmallVectorImpl<ValueOrConst> & shape)242 static void fillShapeFromAllocLike(mlir::OperandRange operands,
243 ShapedType type,
244 SmallVectorImpl<ValueOrConst> &shape) {
245 assert(type.hasRank());
246 auto dynamic_sizes = operands.begin();
247 for (auto extent : type.getShape()) {
248 shape.push_back(ShapedType::isDynamic(extent)
249 ? ValueOrConst{*(dynamic_sizes++)}
250 : ValueOrConst{extent});
251 }
252 }
253
254 /// Registers the value `value` to have the shape represented by `shape`. If
255 /// `shape` has been registered before, place `value` into the same
256 /// equivalence class. Otherwise register `value` as an equivalence class of
257 /// its own.
registerAssociation(ShapeValue shape,Value value)258 void registerAssociation(ShapeValue shape, Value value) {
259 LLVM_DEBUG({ llvm::dbgs() << "Processing " << value << "\n"; });
260 auto insert_symbolic = symbolic_shapes_.insert({shape, value});
261 if (insert_symbolic.second) {
262 LLVM_DEBUG({ llvm::dbgs() << "New symbolic shape " << shape << "\n"; });
263 equal_shapes_.insert(value.getAsOpaquePointer());
264 // We have seen this symbolic shape for the first time. Try to match it
265 // with a vector or shape we already know and alias classes if possible.
266 // This could be based on shape dialect if we weren't late in the
267 // lowering.
268 tryEvaluateShapeToRoot(shape, value);
269 } else {
270 auto rep = insert_symbolic.first->second;
271 LLVM_DEBUG({ llvm::dbgs() << "Aliasing with rep " << rep << "\n"; });
272 equal_shapes_.unionSets(rep.getAsOpaquePointer(),
273 value.getAsOpaquePointer());
274 }
275 }
276
277 /// Follows the definition chains of the ShapeValue `shape` to identify cases
278 /// where `shape` is derived from some other value's shape. In such case, the
279 /// equivalence classes of that other value and `value` are unioned.
280 /// This is based on pattern matching and not complete.
tryEvaluateShapeToRoot(ShapeValue shape,Value value)281 void tryEvaluateShapeToRoot(ShapeValue shape, Value value) {
282 // Just some pattern matching for common cases here.
283 if (!shape.isVector()) {
284 // Patterns that revolve around scalars.
285 // Check whether the scalars are all dim operations for some other memref.
286 Value candidate;
287 bool all_are_dimops =
288 llvm::all_of(llvm::enumerate(shape.scalars()), [&candidate](auto p) {
289 ValueOrConst val = p.value();
290 if (val.isConstant()) return false;
291 auto dimOp = val.value().getDefiningOp<DimOp>();
292 if (!dimOp) return false;
293 if (!candidate) candidate = dimOp.memrefOrTensor();
294 auto index = dimOp.getConstantIndex();
295 if (!index.hasValue()) return false;
296 return candidate == dimOp.memrefOrTensor() &&
297 p.index() == index.getValue();
298 });
299 if (all_are_dimops && candidate) {
300 equal_shapes_.unionSets(candidate.getAsOpaquePointer(),
301 value.getAsOpaquePointer());
302 }
303 }
304 }
305
306 // These are values with identical shapes (or rather their opaque pointers).
307 llvm::EquivalenceClasses<void *> equal_shapes_;
308 // A map from a value that encodes a shape to a value that has this shape.
309 llvm::DenseMap<ShapeValue, Value> symbolic_shapes_;
310 };
311
312 /// For arguments to kernels that have the same shape, use the stride and
313 /// shape information of the left-most argument inside of the kernel function.
314 /// That way, llvm can CSE index computations on same-shaped inputs.
315 struct PropagateShapeKnowledgeToKernels
316 : public PropagateShapeKnowledgeToKernelsBase<
317 PropagateShapeKnowledgeToKernels> {
runOnFunctionmlir::kernel_gen::transforms::__anonb552731d0211::PropagateShapeKnowledgeToKernels318 void runOnFunction() override {
319 ShapeEqualityKnowledge knowledge;
320
321 knowledge.build(getFunction());
322
323 getFunction().walk([&](gpu::LaunchFuncOp launch) {
324 auto module = launch->getParentOfType<ModuleOp>();
325 auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
326
327 if (!kernel || kernel.isExternal()) return;
328
329 llvm::SmallVector<std::pair<Value, int>, 4> seen_memrefs;
330 // Position of the kernel argument we are currently at.
331 int kernel_p = 0;
332 for (auto operand : launch.operands()) {
333 auto memref = operand.getType().dyn_cast<MemRefType>();
334 if (!memref) {
335 // Scalar argument, advance kernel position by one.
336 kernel_p++;
337 continue;
338 }
339 for (auto previous : seen_memrefs) {
340 if (!knowledge.haveSameShape(operand, previous.first)) {
341 continue;
342 }
343 auto previous_type = previous.first.getType().cast<MemRefType>();
344 // We use the first equality found and replace uses of corresponding
345 // size and (potentially) stride information here.
346 auto args_to_replace = memref.getRank();
347 auto all_maps_are_identity = [](ArrayRef<AffineMap> maps) {
348 return llvm::all_of(maps,
349 [](AffineMap map) { return map.isIdentity(); });
350 };
351 // If both memrefs have identity maps, we can also reuse the strides
352 // here, as they are the identity strides and hence fully determinded
353 // by the shape.
354 if (all_maps_are_identity(previous_type.getAffineMaps()) &&
355 all_maps_are_identity(memref.getAffineMaps())) {
356 args_to_replace *= 2;
357 }
358 int previous_args_pos = previous.second;
359 auto previous_args = kernel.getArguments()
360 .drop_front(previous_args_pos + 3)
361 .take_front(args_to_replace);
362 auto current_args = kernel.getArguments()
363 .drop_front(kernel_p + 3)
364 .take_front(args_to_replace);
365 for (auto pair : llvm::zip(previous_args, current_args)) {
366 mlir::BlockArgument prev, curr;
367 std::tie(prev, curr) = pair;
368 curr.replaceAllUsesWith(prev);
369 }
370 break;
371 }
372 seen_memrefs.push_back({operand, kernel_p});
373 // Advance base, aligned, offset, strides and sizes many arguments.
374 kernel_p += memref.getRank() * 2 + 3;
375 }
376 });
377 }
378 };
379
380 } // namespace
381
CreatePropagateShapeKnowledgeToKernels()382 std::unique_ptr<FunctionPass> CreatePropagateShapeKnowledgeToKernels() {
383 return std::make_unique<PropagateShapeKnowledgeToKernels>();
384 }
385
386 } // namespace transforms
387 } // namespace kernel_gen
388 } // namespace mlir
389