• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file contains the analysis and transformation to rewrite kernel
17 // functions such that they use a single set of arguments for the strides and
18 // sizes of operands with equal shapes.
19 
20 #include <memory>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMapInfo.h"
24 #include "llvm/ADT/EquivalenceClasses.h"
25 #include "llvm/ADT/Hashing.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
29 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
31 #include "mlir/IR/AsmState.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/Value.h"  // from @llvm-project
35 #include "mlir/Support/LLVM.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
38 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
39 
40 #define DEBUG_TYPE "kernel-gen-shapes"
41 
42 namespace {
43 
44 using mlir::ArrayRef;
45 using mlir::SmallVector;
46 using mlir::Value;
47 
48 /// Represents a value or constant. Used to unify operands for operations that
49 /// take both ssa values and attributes.
50 struct ValueOrConst {
ValueOrConst__anonb552731d0111::ValueOrConst51   explicit ValueOrConst(Value v) : value_or_constant(v), is_constant(false) {}
ValueOrConst__anonb552731d0111::ValueOrConst52   explicit ValueOrConst(int64_t c) : value_or_constant(c), is_constant(true) {}
53 
value__anonb552731d0111::ValueOrConst54   Value value() const {
55     assert(!is_constant);
56     return value_or_constant.value;
57   }
58 
constant__anonb552731d0111::ValueOrConst59   int64_t constant() const {
60     assert(is_constant);
61     return value_or_constant.constant;
62   }
63 
isConstant__anonb552731d0111::ValueOrConst64   bool isConstant() const { return is_constant; }
65 
66  private:
67   union ValueOrConstStorage {
ValueOrConstStorage(Value v)68     explicit ValueOrConstStorage(Value v) : value(v) {}
ValueOrConstStorage(size_t c)69     explicit ValueOrConstStorage(size_t c) : constant(c) {}
70 
71     Value value;
72     int64_t constant;
73   } value_or_constant;
74 
75   bool is_constant;
76 };
77 
hash_value(ValueOrConst value)78 llvm::hash_code hash_value(ValueOrConst value) {
79   return value.isConstant() ? static_cast<llvm::hash_code>(value.constant())
80                             : mlir::hash_value(value.value());
81 }
82 
operator ==(ValueOrConst lhs,ValueOrConst rhs)83 bool operator==(ValueOrConst lhs, ValueOrConst rhs) {
84   if (lhs.isConstant()) {
85     return rhs.isConstant() && lhs.constant() == rhs.constant();
86   } else {
87     return !rhs.isConstant() && lhs.value() == rhs.value();
88   }
89 }
90 
operator <<(llvm::raw_ostream & os,const ValueOrConst & value)91 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
92                                      const ValueOrConst &value) {
93   if (value.isConstant()) {
94     os << value.constant();
95   } else {
96     Value val = value.value();
97     mlir::AsmState asm_state(
98         val.getParentRegion()->getParentOfType<mlir::FuncOp>());
99     val.printAsOperand(os, asm_state);
100   }
101   return os;
102 }
103 
104 /// Represents a shape, as either a single SSA value that represents the entire
105 /// shape vector or as a vector of SSA values representing scalars.
106 struct ShapeValue {
ShapeValue__anonb552731d0111::ShapeValue107   explicit ShapeValue(Value vector)
108       : shape({ValueOrConst{vector}}), is_vector(true) {}
ShapeValue__anonb552731d0111::ShapeValue109   explicit ShapeValue(ValueOrConst vector) : shape({vector}), is_vector(true) {
110     assert(!vector.isConstant());
111   }
112   template <typename T>
ShapeValue__anonb552731d0111::ShapeValue113   explicit ShapeValue(T values)
114       : shape(values.begin(), values.end()), is_vector(false) {}
115 
vector__anonb552731d0111::ShapeValue116   ValueOrConst vector() const {
117     assert(is_vector);
118     return shape.front();
119   }
120 
scalars__anonb552731d0111::ShapeValue121   ArrayRef<ValueOrConst> scalars() const {
122     assert(!is_vector);
123     return llvm::makeArrayRef(shape);
124   }
125 
isVector__anonb552731d0111::ShapeValue126   bool isVector() const { return is_vector; }
127 
128  private:
129   SmallVector<ValueOrConst, 4> shape;
130   bool is_vector;
131 };
132 
hash_value(ShapeValue shape)133 llvm::hash_code hash_value(ShapeValue shape) {
134   return shape.isVector() ? hash_value(shape.vector())
135                           : hash_value(shape.scalars());
136 }
137 
operator ==(ShapeValue lhs,ShapeValue rhs)138 bool operator==(ShapeValue lhs, ShapeValue rhs) {
139   if (lhs.isVector()) {
140     return rhs.isVector() && lhs.vector() == rhs.vector();
141   } else {
142     return !rhs.isVector() && lhs.scalars() == rhs.scalars();
143   }
144 }
145 
operator <<(llvm::raw_ostream & os,const ShapeValue & shape)146 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
147                                      const ShapeValue &shape) {
148   if (shape.isVector()) {
149     os << shape.vector();
150     return os;
151   }
152   os << "[";
153   bool first = true;
154   for (auto scalar : shape.scalars()) {
155     if (!first) {
156       os << ", ";
157     }
158     first = false;
159     os << scalar;
160   }
161   os << "]";
162   return os;
163 }
164 
165 }  // namespace
166 
167 namespace llvm {
168 
169 template <>
170 struct DenseMapInfo<ShapeValue> {
getEmptyKeyllvm::DenseMapInfo171   static ShapeValue getEmptyKey() {
172     return ShapeValue(DenseMapInfo<mlir::Value>::getEmptyKey());
173   }
getTombstoneKeyllvm::DenseMapInfo174   static ShapeValue getTombstoneKey() {
175     return ShapeValue(DenseMapInfo<mlir::Value>::getTombstoneKey());
176   }
getHashValuellvm::DenseMapInfo177   static unsigned getHashValue(ShapeValue shape) { return hash_value(shape); }
isEqualllvm::DenseMapInfo178   static bool isEqual(ShapeValue LHS, ShapeValue RHS) { return LHS == RHS; }
179 };
180 
181 }  // namespace llvm
182 
183 namespace mlir {
184 namespace kernel_gen {
185 namespace transforms {
186 
187 namespace {
188 
189 #define GEN_PASS_CLASSES
190 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
191 
192 // A basic shape equality inference. This should be superceeded by a proper
193 // inference once available. Until then, we just build this out to the needs of
194 // the kernel generator project.
195 class ShapeEqualityKnowledge {
196  public:
197   /// Checks all operations for potential shape equality of their respective
198   /// results.
build(FuncOp function)199   void build(FuncOp function) {
200     function.walk([&](Operation *op) {
201       if (auto reshape = dyn_cast<MemRefReshapeOp>(op)) {
202         registerAssociation(ShapeValue{reshape.shape()}, reshape.result());
203         return;
204       }
205       if (auto cast = dyn_cast<MemRefReinterpretCastOp>(op)) {
206         // Only support fully dynamic sizes for now.
207         // TODO(herhut): Fix once the op has canonicalizers that break this.
208         for (unsigned int p = 0, e = cast.getResultRank(); p < e; ++p) {
209           if (!cast.isDynamicSize(p)) {
210             return;
211           }
212         }
213         registerAssociation(ShapeValue{cast.sizes()}, cast.result());
214         return;
215       }
216       if (auto alloc = dyn_cast<AllocOp>(op)) {
217         SmallVector<ValueOrConst, 4> shape;
218         ShapedType type = alloc.getResult().getType().cast<ShapedType>();
219         fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape);
220         registerAssociation(ShapeValue{shape}, alloc.getResult());
221         return;
222       }
223       if (auto alloc = dyn_cast<tf_framework::TFAllocOp>(op)) {
224         // Construct a symbol representing the allocated shape.
225         SmallVector<ValueOrConst, 4> shape;
226         ShapedType type = alloc.getResult().getType().cast<ShapedType>();
227         fillShapeFromAllocLike(alloc.dyn_sizes(), type, shape);
228         registerAssociation(ShapeValue{shape}, alloc.getResult());
229         return;
230       }
231     });
232   }
233 
234   /// Checks whether `one` and `other` are known to have the same shape and
235   /// strides.
haveSameShape(Value one,Value other)236   bool haveSameShape(Value one, Value other) {
237     return equal_shapes_.isEquivalent(one.getAsOpaquePointer(),
238                                       other.getAsOpaquePointer());
239   }
240 
241  private:
fillShapeFromAllocLike(mlir::OperandRange operands,ShapedType type,SmallVectorImpl<ValueOrConst> & shape)242   static void fillShapeFromAllocLike(mlir::OperandRange operands,
243                                      ShapedType type,
244                                      SmallVectorImpl<ValueOrConst> &shape) {
245     assert(type.hasRank());
246     auto dynamic_sizes = operands.begin();
247     for (auto extent : type.getShape()) {
248       shape.push_back(ShapedType::isDynamic(extent)
249                           ? ValueOrConst{*(dynamic_sizes++)}
250                           : ValueOrConst{extent});
251     }
252   }
253 
254   /// Registers the value `value` to have the shape represented by `shape`. If
255   /// `shape` has been registered before, place `value` into the same
256   /// equivalence class. Otherwise register `value` as an equivalence class of
257   /// its own.
registerAssociation(ShapeValue shape,Value value)258   void registerAssociation(ShapeValue shape, Value value) {
259     LLVM_DEBUG({ llvm::dbgs() << "Processing " << value << "\n"; });
260     auto insert_symbolic = symbolic_shapes_.insert({shape, value});
261     if (insert_symbolic.second) {
262       LLVM_DEBUG({ llvm::dbgs() << "New symbolic shape " << shape << "\n"; });
263       equal_shapes_.insert(value.getAsOpaquePointer());
264       // We have seen this symbolic shape for the first time. Try to match it
265       // with a vector or shape we already know and alias classes if possible.
266       // This could be based on shape dialect if we weren't late in the
267       // lowering.
268       tryEvaluateShapeToRoot(shape, value);
269     } else {
270       auto rep = insert_symbolic.first->second;
271       LLVM_DEBUG({ llvm::dbgs() << "Aliasing with rep " << rep << "\n"; });
272       equal_shapes_.unionSets(rep.getAsOpaquePointer(),
273                               value.getAsOpaquePointer());
274     }
275   }
276 
277   /// Follows the definition chains of the ShapeValue `shape` to identify cases
278   /// where `shape` is derived from some other value's shape. In such case, the
279   /// equivalence classes of that other value and `value` are unioned.
280   /// This is based on pattern matching and not complete.
tryEvaluateShapeToRoot(ShapeValue shape,Value value)281   void tryEvaluateShapeToRoot(ShapeValue shape, Value value) {
282     // Just some pattern matching for common cases here.
283     if (!shape.isVector()) {
284       // Patterns that revolve around scalars.
285       // Check whether the scalars are all dim operations for some other memref.
286       Value candidate;
287       bool all_are_dimops =
288           llvm::all_of(llvm::enumerate(shape.scalars()), [&candidate](auto p) {
289             ValueOrConst val = p.value();
290             if (val.isConstant()) return false;
291             auto dimOp = val.value().getDefiningOp<DimOp>();
292             if (!dimOp) return false;
293             if (!candidate) candidate = dimOp.memrefOrTensor();
294             auto index = dimOp.getConstantIndex();
295             if (!index.hasValue()) return false;
296             return candidate == dimOp.memrefOrTensor() &&
297                    p.index() == index.getValue();
298           });
299       if (all_are_dimops && candidate) {
300         equal_shapes_.unionSets(candidate.getAsOpaquePointer(),
301                                 value.getAsOpaquePointer());
302       }
303     }
304   }
305 
306   // These are values with identical shapes (or rather their opaque pointers).
307   llvm::EquivalenceClasses<void *> equal_shapes_;
308   // A map from a value that encodes a shape to a value that has this shape.
309   llvm::DenseMap<ShapeValue, Value> symbolic_shapes_;
310 };
311 
312 /// For arguments to kernels that have the same shape, use the stride and
313 /// shape information of the left-most argument inside of the kernel function.
314 /// That way, llvm can CSE index computations on same-shaped inputs.
315 struct PropagateShapeKnowledgeToKernels
316     : public PropagateShapeKnowledgeToKernelsBase<
317           PropagateShapeKnowledgeToKernels> {
runOnFunctionmlir::kernel_gen::transforms::__anonb552731d0211::PropagateShapeKnowledgeToKernels318   void runOnFunction() override {
319     ShapeEqualityKnowledge knowledge;
320 
321     knowledge.build(getFunction());
322 
323     getFunction().walk([&](gpu::LaunchFuncOp launch) {
324       auto module = launch->getParentOfType<ModuleOp>();
325       auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
326 
327       if (!kernel || kernel.isExternal()) return;
328 
329       llvm::SmallVector<std::pair<Value, int>, 4> seen_memrefs;
330       // Position of the kernel argument we are currently at.
331       int kernel_p = 0;
332       for (auto operand : launch.operands()) {
333         auto memref = operand.getType().dyn_cast<MemRefType>();
334         if (!memref) {
335           // Scalar argument, advance kernel position by one.
336           kernel_p++;
337           continue;
338         }
339         for (auto previous : seen_memrefs) {
340           if (!knowledge.haveSameShape(operand, previous.first)) {
341             continue;
342           }
343           auto previous_type = previous.first.getType().cast<MemRefType>();
344           // We use the first equality found and replace uses of corresponding
345           // size and (potentially) stride information here.
346           auto args_to_replace = memref.getRank();
347           auto all_maps_are_identity = [](ArrayRef<AffineMap> maps) {
348             return llvm::all_of(maps,
349                                 [](AffineMap map) { return map.isIdentity(); });
350           };
351           // If both memrefs have identity maps, we can also reuse the strides
352           // here, as they are the identity strides and hence fully determinded
353           // by the shape.
354           if (all_maps_are_identity(previous_type.getAffineMaps()) &&
355               all_maps_are_identity(memref.getAffineMaps())) {
356             args_to_replace *= 2;
357           }
358           int previous_args_pos = previous.second;
359           auto previous_args = kernel.getArguments()
360                                    .drop_front(previous_args_pos + 3)
361                                    .take_front(args_to_replace);
362           auto current_args = kernel.getArguments()
363                                   .drop_front(kernel_p + 3)
364                                   .take_front(args_to_replace);
365           for (auto pair : llvm::zip(previous_args, current_args)) {
366             mlir::BlockArgument prev, curr;
367             std::tie(prev, curr) = pair;
368             curr.replaceAllUsesWith(prev);
369           }
370           break;
371         }
372         seen_memrefs.push_back({operand, kernel_p});
373         // Advance base, aligned, offset, strides and sizes many arguments.
374         kernel_p += memref.getRank() * 2 + 3;
375       }
376     });
377   }
378 };
379 
380 }  // namespace
381 
CreatePropagateShapeKnowledgeToKernels()382 std::unique_ptr<FunctionPass> CreatePropagateShapeKnowledgeToKernels() {
383   return std::make_unique<PropagateShapeKnowledgeToKernels>();
384 }
385 
386 }  // namespace transforms
387 }  // namespace kernel_gen
388 }  // namespace mlir
389