/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file provides basic utilities for the elemental lowering of
// each node.

#include "mlir-hlo/Dialect/mhlo/transforms/lhlo_elemental_utils.h"

#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/utils/codegen_utils.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using mlir::memref::DimOp;
using mlir::memref::LoadOp;
using mlir::memref::StoreOp;

namespace mlir {
namespace lmhlo {

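// Loads the element of `memref` at `indices`. If a value has already been
// stored to the same location earlier in the block containing `insert_point`,
// that value is reused instead of emitting a new load.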
Value createLoadOrUseCachedValue(Location loc, OpBuilder* b, Value memref,
                                 ValueRange indices,
                                 OpBuilder::InsertPoint insert_point) {
  // Check if there is a cached value that can be reused within the
  // current Block. Alternatively we could do this for all the Blocks
  // that dominate this Block, but that would be more complicated.
  std::vector<StoreOp> store_ops;
  insert_point.getBlock()->walk(
      insert_point.getBlock()->begin(), insert_point.getPoint(),
      [&](StoreOp store_op) {
        if (store_op.getOperation()->getBlock() != insert_point.getBlock())
          return;
        if ((store_op.getMemRef() == memref) &&
            (store_op.getIndices() == indices))
          store_ops.emplace_back(store_op);
      });
  if (!store_ops.empty()) return store_ops[0].getOperand(0);
  int rank = memref.getType().dyn_cast<MemRefType>().getRank();
  return rank > 0 ? b->create<LoadOp>(loc, memref, indices)
                  : b->create<LoadOp>(loc, memref);
}

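// Returns the subset of `ops` whose result buffers never reach a memref.load,
// either directly or transitively through other lmhlo ops; such ops have no
// observable consumers.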
DenseSet<Operation*> NoLoaderUser(SmallVectorImpl<Operation*>& ops) {
  SmallVector<Operation*, 4> worklist;
  DenseSet<Operation*> has_loader_ops;
  for (Operation* op : ops) {
    Value memref = cast<LmhloOp>(op).getResultBuffer();
    if (memref == nullptr) continue;
    for (auto* user : memref.getUsers()) {
      if (isa<memref::LoadOp>(user)) {
        worklist.push_back(op);
        has_loader_ops.insert(op);
      }
    }
  }

  while (!worklist.empty()) {
    Operation* op = worklist.pop_back_val();
    int num_operands = op->getNumOperands();
    for (int i = 0; i < num_operands - 1; ++i) {
      Value memref = op->getOperand(i);
      for (Operation* user : memref.getUsers()) {
        if ((!isa<LmhloOp>(user)) || has_loader_ops.count(user)) continue;
        if (cast<LmhloOp>(user).getResultBuffer() == memref) {
          worklist.push_back(user);
          has_loader_ops.insert(user);
        }
      }
    }
  }

  DenseSet<Operation*> no_loader_ops;
  for (Operation* op : ops)
    if (!has_loader_ops.count(op)) no_loader_ops.insert(op);
  return no_loader_ops;
}

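// Erases the lmhlo ops in `parent` whose result buffers are never read (see
// NoLoaderUser above).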
void cleanUnusedLhloOps(Block* parent) {
  SmallVector<Operation*, 4> lhlo_ops;
  for (Operation& op : parent->getOperations()) {
    if (op.getDialect() == op.getContext()->getLoadedDialect("lmhlo") &&
        (!isa<lmhlo::TerminatorOp>(op)))
      lhlo_ops.push_back(&op);
  }
  const DenseSet<Operation*>& no_loader_user = NoLoaderUser(lhlo_ops);
  for (auto* lhlo_op : no_loader_user) lhlo_op->erase();
}

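// Emits the scalar computation of `op` for the element at `output_index` and
// returns the resulting value. When `check_cache` is true, values already
// stored in the current block are reused instead of re-loading the operand
// buffers.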
template <typename LHLO_OpTy>
Value elementalLower(OpBuilder* b, Location loc, LHLO_OpTy op,
                     ValueRange output_index, bool check_cache);

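// For RealDynamicSliceOp, operand 1 holds the start indices and operand 3
// holds the strides; each output dimension maps to the input index
// output_index[dim] * stride[dim] + start_index[dim].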
template <>
Value elementalLower<lmhlo::RealDynamicSliceOp>(OpBuilder* b, Location loc,
                                                lmhlo::RealDynamicSliceOp op,
                                                ValueRange output_index,
                                                bool check_cache) {
  Value start_indices_memref = op->getOperand(1);
  Value strides_memref = op->getOperand(3);
  int rank = output_index.size();
  SmallVector<Value, 4> input_index;
  for (int dim = 0; dim < rank; ++dim) {
    SmallVector<Value, 4> dim_index;
    dim_index.push_back(b->create<ConstantOp>(
        loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), dim)));
    auto start_index_load =
        b->create<LoadOp>(loc, start_indices_memref, ValueRange{dim_index});
    auto start_index =
        b->create<IndexCastOp>(loc, b->getIndexType(), start_index_load);
    auto stride_load =
        b->create<LoadOp>(loc, strides_memref, ValueRange{dim_index});
    auto stride = b->create<IndexCastOp>(loc, b->getIndexType(), stride_load);
    // input_dim = out_dim * stride + start_index
    auto input_dim = b->create<AddIOp>(
        loc, b->create<MulIOp>(loc, output_index[dim], stride), start_index);
    input_index.push_back(input_dim);
  }

  Value operand_memref = *(op->getOperands().begin());

  if (!check_cache) return b->create<LoadOp>(loc, operand_memref, input_index);
  return createLoadOrUseCachedValue(loc, b, operand_memref, input_index,
                                    b->saveInsertionPoint());
}

namespace {

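// Shared implementation for BroadcastInDimOp and DynamicBroadcastInDimOp:
// maps `output_index` back to an index into the operand. Output dims listed
// in broadcast_dimensions map to the corresponding operand dim; operand dims
// with static size 1 always map to index 0, and dynamically sized dims select
// between 0 and the output index at runtime.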
template <typename T>
Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc,
                                             T broadcast_in_dim,
                                             ValueRange output_index,
                                             bool check_cache) {
  auto broadcast_dimensions =
      broadcast_in_dim.broadcast_dimensions().template getValues<int64_t>();
  int out_rank = output_index.size();
  Value operand_memref = broadcast_in_dim->getOperand(0);
  SmallVector<Value, 4> input_index;
  for (int64_t dim = 0; dim < out_rank; ++dim) {
    auto it = std::find(broadcast_dimensions.begin(),
                        broadcast_dimensions.end(), dim);

    bool is_broadcast_dim = (it != broadcast_dimensions.end());
    if (is_broadcast_dim) {
      int input_dim = std::distance(broadcast_dimensions.begin(), it);
      int64_t static_dim_size =
          operand_memref.getType().cast<MemRefType>().getShape()[input_dim];
      if (static_dim_size == 1) {
        // We know at compile time that this dim is broadcast.
        auto zero = b->create<ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        input_index.push_back(zero);
      } else if (static_dim_size == ShapedType::kDynamicSize) {
        // We cannot tell at compile time whether this dim is broadcast.
        auto dim_size = b->create<DimOp>(loc, operand_memref, input_dim);
        auto one = b->create<ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 1));
        auto zero = b->create<ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        auto dim_size_is_1 =
            b->create<CmpIOp>(loc, CmpIPredicate::eq, dim_size, one);
        input_index.push_back(b->create<mlir::SelectOp>(
            loc, dim_size_is_1, zero, output_index[dim]));
      } else {
        // We know at compile time that this dim is not broadcast.
        input_index.push_back(output_index[dim]);
      }
    }
  }

  if (!check_cache) {
    int rank = operand_memref.getType().dyn_cast<MemRefType>().getRank();
    return (rank > 0) ? b->create<LoadOp>(loc, operand_memref, input_index)
                      : b->create<LoadOp>(loc, operand_memref, ValueRange());
  }
  return createLoadOrUseCachedValue(loc, b, operand_memref, input_index,
                                    b->saveInsertionPoint());
}

}  // namespace

template <>
Value elementalLower<lmhlo::DynamicBroadcastInDimOp>(
    OpBuilder* b, Location loc, lmhlo::DynamicBroadcastInDimOp op,
    ValueRange output_index, bool check_cache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, output_index,
                                                check_cache);
}

template <>
Value elementalLower<lmhlo::BroadcastInDimOp>(OpBuilder* b, Location loc,
                                              lmhlo::BroadcastInDimOp op,
                                              ValueRange output_index,
                                              bool check_cache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, output_index,
                                                check_cache);
}

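// Creates an scf.for loop over [lb, ub) with the given step and init values,
// sets the insertion point of `b` to the start of the loop body, and returns
// the loop. `var` is set to the induction variable.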
scf::ForOp createLoopAndSetInsPt(OpBuilder& b, Location loc, Value& var,
                                 Value lb, Value ub, Value step,
                                 ArrayRef<Value> init_values) {
  auto for_op = b.create<scf::ForOp>(loc, lb, ub, step, init_values);
  b.setInsertionPointToStart(for_op.getBody());
  var = for_op.getInductionVar();
  return for_op;
}

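// Creates an scf.parallel loop with the given bounds, steps and init values,
// sets the insertion point of `b` to the start of the loop body, and appends
// the induction variables to `vars`.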
scf::ParallelOp createParallelAndSetInsPt(OpBuilder& b, Location loc,
                                          SmallVectorImpl<Value>& vars,
                                          ArrayRef<Value> lbs,
                                          ArrayRef<Value> ubs,
                                          ArrayRef<Value> steps,
                                          ArrayRef<Value> init_values) {
  auto par_op = b.create<scf::ParallelOp>(loc, lbs, ubs, steps, init_values,
                                          /*bodyBuilderFn=*/nullptr);
  b.setInsertionPointToStart(par_op.getBody());
  vars.append(par_op.getInductionVars().begin(),
              par_op.getInductionVars().end());
  return par_op;
}

// reinterpret_cast the input memref into 1D
memref::ReinterpretCastOp createMemRef1DReinterpretCast(OpBuilder& b,
                                                        Location loc,
                                                        Value memref) {
  auto memref_ty = memref.getType().cast<MemRefType>();
  assert(memref_ty.getAffineMaps().empty());
  Value size = codegen_utils::emitNumElementsComputation(b, loc, memref);
  Value stride = b.create<mlir::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1));
  Value zero = b.create<mlir::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 0));
  auto memref_1d_type =
      MemRefType::get({MemRefType::kDynamicSize}, memref_ty.getElementType(),
                      memref_ty.getAffineMaps(), memref_ty.getMemorySpace());
  return b.create<memref::ReinterpretCastOp>(
      loc, memref_1d_type, memref, zero, ValueRange{size}, ValueRange{stride});
}

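// Stores `res` into `memref` at the given linear `offset`, viewing the memref
// as 1D via a reinterpret_cast.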
void createOffsetStore(OpBuilder& b, Location loc, Value res, Value memref,
                       Value offset) {
  Value memref_1d = createMemRef1DReinterpretCast(b, loc, memref);
  b.create<memref::StoreOp>(loc, res, memref_1d, ValueRange{offset});
}

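// Loads the element of `memref` at the given linear `offset`, viewing the
// memref as 1D via a reinterpret_cast.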
memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref,
                                Value offset) {
  Value memref_1d = createMemRef1DReinterpretCast(b, loc, memref);
  return b.create<memref::LoadOp>(loc, memref_1d, ValueRange{offset});
}

}  // namespace lmhlo
}  // namespace mlir