/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file provides basic utilities for the elemental lowering of each node.

#include "mlir-hlo/Dialect/mhlo/transforms/lhlo_elemental_utils.h"

#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/utils/codegen_utils.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using mlir::memref::DimOp;
using mlir::memref::LoadOp;
using mlir::memref::StoreOp;

namespace mlir {
namespace lmhlo {

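// Returns a Value holding the element of `memref` at `indices`. If a
// memref.store to the same memref and the same indices already exists in the
// block of `insert_point` (before the insertion point), the stored value is
// reused instead of emitting a new load. A minimal usage sketch, with
// hypothetical values `operand` and `idx` and an OpBuilder `b`:
//
//   Value v = createLoadOrUseCachedValue(loc, &b, operand, ValueRange{idx},
//                                        b.saveInsertionPoint());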
Value createLoadOrUseCachedValue(Location loc, OpBuilder* b, Value memref,
                                 ValueRange indices,
                                 OpBuilder::InsertPoint insert_point) {
  // Check whether any value cached by an earlier store can be reused within
  // the current Block. Alternatively we could also consider all Blocks that
  // dominate this Block, but that would be more complicated.
  std::vector<StoreOp> store_ops;
  insert_point.getBlock()->walk(
      insert_point.getBlock()->begin(), insert_point.getPoint(),
      [&](StoreOp store_op) {
        if (store_op.getOperation()->getBlock() != insert_point.getBlock())
          return;
        if ((store_op.getMemRef() == memref) &&
            (store_op.getIndices() == indices))
          store_ops.emplace_back(store_op);
      });
  if (!store_ops.empty()) return store_ops[0].getOperand(0);
  int rank = memref.getType().dyn_cast<MemRefType>().getRank();
  return rank > 0 ? b->create<LoadOp>(loc, memref, indices)
                  : b->create<LoadOp>(loc, memref);
}

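// Returns the subset of `ops` whose result buffers are never read back, i.e.
// ops whose output memref has no memref.load user and does not feed, through
// a chain of other lmhlo ops, an op whose output is loaded. Such ops are
// treated as dead and erased by cleanUnusedLhloOps below.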
DenseSet<Operation*> NoLoaderUser(SmallVectorImpl<Operation*>& ops) {
  SmallVector<Operation*, 4> worklist;
  DenseSet<Operation*> has_loader_ops;
  for (Operation* op : ops) {
    Value memref = cast<LmhloOp>(op).getResultBuffer();
    if (memref == nullptr) continue;
    for (auto* user : memref.getUsers()) {
      if (isa<memref::LoadOp>(user)) {
        worklist.push_back(op);
        has_loader_ops.insert(op);
      }
    }
  }

  while (!worklist.empty()) {
    Operation* op = worklist.pop_back_val();
    int num_operands = op->getNumOperands();
    for (int i = 0; i < num_operands - 1; ++i) {
      Value memref = op->getOperand(i);
      for (Operation* user : memref.getUsers()) {
        if ((!isa<LmhloOp>(user)) || has_loader_ops.count(user)) continue;
        if (cast<LmhloOp>(user).getResultBuffer() == memref) {
          worklist.push_back(user);
          has_loader_ops.insert(user);
        }
      }
    }
  }

  DenseSet<Operation*> no_loader_ops;
  for (Operation* op : ops)
    if (!has_loader_ops.count(op)) no_loader_ops.insert(op);
  return no_loader_ops;
}

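// Erases every lmhlo op in `parent` (except the terminator) whose result
// buffer is never loaded, as computed by NoLoaderUser above.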
void cleanUnusedLhloOps(Block* parent) {
  SmallVector<Operation*, 4> lhlo_ops;
  for (Operation& op : parent->getOperations()) {
    if (op.getDialect() == op.getContext()->getLoadedDialect("lmhlo") &&
        (!isa<lmhlo::TerminatorOp>(op)))
      lhlo_ops.push_back(&op);
  }
  const DenseSet<Operation*>& no_loader_user = NoLoaderUser(lhlo_ops);
  for (auto* lhlo_op : no_loader_user) lhlo_op->erase();
}

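// elementalLower<OpTy> emits the scalar computation that produces the single
// output element of `op` addressed by `output_index` and returns the
// resulting Value. When `check_cache` is true, operand loads go through
// createLoadOrUseCachedValue so that values already stored in the current
// block are reused instead of reloaded.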
template <typename LHLO_OpTy>
Value elementalLower(OpBuilder* b, Location loc, LHLO_OpTy op,
                     ValueRange output_index, bool check_cache);

template <>
Value elementalLower<lmhlo::RealDynamicSliceOp>(OpBuilder* b, Location loc,
                                                lmhlo::RealDynamicSliceOp op,
                                                ValueRange output_index,
                                                bool check_cache) {
  Value start_indices_memref = op->getOperand(1);
  Value strides_memref = op->getOperand(3);
  int rank = output_index.size();
  SmallVector<Value, 4> input_index;
  for (int dim = 0; dim < rank; ++dim) {
    SmallVector<Value, 4> dim_index;
    dim_index.push_back(b->create<ConstantOp>(
        loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), dim)));
    auto start_index_load =
        b->create<LoadOp>(loc, start_indices_memref, ValueRange{dim_index});
    auto start_index =
        b->create<IndexCastOp>(loc, b->getIndexType(), start_index_load);
    auto stride_load =
        b->create<LoadOp>(loc, strides_memref, ValueRange{dim_index});
    auto stride = b->create<IndexCastOp>(loc, b->getIndexType(), stride_load);
    // input_dim = out_dim * stride + start_index
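    // E.g. with start_index = 2 and stride = 3, output index 1 reads input
    // index 1 * 3 + 2 = 5.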
    auto input_dim = b->create<AddIOp>(
        loc, b->create<MulIOp>(loc, output_index[dim], stride), start_index);
    input_index.push_back(input_dim);
  }

  Value operand_memref = *(op->getOperands().begin());

  if (!check_cache) return b->create<LoadOp>(loc, operand_memref, input_index);
  return createLoadOrUseCachedValue(loc, b, operand_memref, input_index,
                                    b->saveInsertionPoint());
}

namespace {

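// Shared implementation for BroadcastInDimOp and DynamicBroadcastInDimOp:
// maps `output_index` to an index into the operand. Output dims that do not
// appear in broadcast_dimensions are dropped; operand dims of static size 1
// are indexed with 0; for dynamically sized operand dims the choice between 0
// and the output index is made at runtime via a select on (dim_size == 1).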
template <typename T>
Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc,
                                             T broadcast_in_dim,
                                             ValueRange output_index,
                                             bool check_cache) {
  auto broadcast_dimensions =
      broadcast_in_dim.broadcast_dimensions().template getValues<int64_t>();
  int out_rank = output_index.size();
  Value operand_memref = broadcast_in_dim->getOperand(0);
  SmallVector<Value, 4> input_index;
  for (int64_t dim = 0; dim < out_rank; ++dim) {
    auto it = std::find(broadcast_dimensions.begin(),
                        broadcast_dimensions.end(), dim);

    bool is_broadcast_dim = (it != broadcast_dimensions.end());
    if (is_broadcast_dim) {
      int input_dim = std::distance(broadcast_dimensions.begin(), it);
      int64_t static_dim_size =
          operand_memref.getType().cast<MemRefType>().getShape()[input_dim];
      if (static_dim_size == 1) {
        // We know at compile time that this dim is to be broadcast.
        auto zero = b->create<ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        input_index.push_back(zero);
      } else if (static_dim_size == ShapedType::kDynamicSize) {
        // Whether this dim is to be broadcast is not known at compile time.
        auto dim_size = b->create<DimOp>(loc, operand_memref, input_dim);
        auto one = b->create<ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 1));
        auto zero = b->create<ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        auto dim_size_is_1 =
            b->create<CmpIOp>(loc, CmpIPredicate::eq, dim_size, one);
        input_index.push_back(b->create<mlir::SelectOp>(
            loc, dim_size_is_1, zero, output_index[dim]));
      } else {
        // We know at compile time that this dim is not to be broadcast.
        input_index.push_back(output_index[dim]);
      }
    }
  }

  if (!check_cache) {
    int rank = operand_memref.getType().dyn_cast<MemRefType>().getRank();
    return (rank > 0) ? b->create<LoadOp>(loc, operand_memref, input_index)
                      : b->create<LoadOp>(loc, operand_memref, ValueRange());
  }
  return createLoadOrUseCachedValue(loc, b, operand_memref, input_index,
                                    b->saveInsertionPoint());
}

}  // namespace

template <>
Value elementalLower<lmhlo::DynamicBroadcastInDimOp>(
    OpBuilder* b, Location loc, lmhlo::DynamicBroadcastInDimOp op,
    ValueRange output_index, bool check_cache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, output_index,
                                                check_cache);
}

template <>
Value elementalLower<lmhlo::BroadcastInDimOp>(OpBuilder* b, Location loc,
                                              lmhlo::BroadcastInDimOp op,
                                              ValueRange output_index,
                                              bool check_cache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, output_index,
                                                check_cache);
}

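// Creates an scf.for loop over [lb, ub) with the given step and loop-carried
// init_values, moves the builder insertion point to the start of the loop
// body, and returns the loop. `var` is set to the induction variable.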
scf::ForOp createLoopAndSetInsPt(OpBuilder& b, Location loc, Value& var,
                                 Value lb, Value ub, Value step,
                                 ArrayRef<Value> init_values) {
  auto for_op = b.create<scf::ForOp>(loc, lb, ub, step, init_values);
  b.setInsertionPointToStart(for_op.getBody());
  var = for_op.getInductionVar();
  return for_op;
}

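// Parallel counterpart of createLoopAndSetInsPt: creates an scf.parallel loop
// nest over [lbs, ubs) with the given steps and init_values, moves the
// insertion point to the start of its body, and appends the induction
// variables to `vars`.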
scf::ParallelOp createParallelAndSetInsPt(OpBuilder& b, Location loc,
                                          SmallVectorImpl<Value>& vars,
                                          ArrayRef<Value> lbs,
                                          ArrayRef<Value> ubs,
                                          ArrayRef<Value> steps,
                                          ArrayRef<Value> init_values) {
  auto par_op = b.create<scf::ParallelOp>(loc, lbs, ubs, steps, init_values,
                                          /*bodyBuilderFn=*/nullptr);
  b.setInsertionPointToStart(par_op.getBody());
  vars.append(par_op.getInductionVars().begin(),
              par_op.getInductionVars().end());
  return par_op;
}

// Reinterpret-casts the input memref into a 1-D memref with the same total
// number of elements. The memref must not carry an affine layout map (see the
// assert below).
memref::ReinterpretCastOp createMemRef1DReinterpretCast(OpBuilder& b,
                                                        Location loc,
                                                        Value memref) {
  auto memref_ty = memref.getType().cast<MemRefType>();
  assert(memref_ty.getAffineMaps().empty());
  Value size = codegen_utils::emitNumElementsComputation(b, loc, memref);
  Value stride = b.create<mlir::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1));
  Value zero = b.create<mlir::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 0));
  auto memref_1d_type =
      MemRefType::get({MemRefType::kDynamicSize}, memref_ty.getElementType(),
                      memref_ty.getAffineMaps(), memref_ty.getMemorySpace());
  return b.create<memref::ReinterpretCastOp>(
      loc, memref_1d_type, memref, zero, ValueRange{size}, ValueRange{stride});
}

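// Stores `res` into `memref` at linear position `offset`, viewing the memref
// as 1-D via createMemRef1DReinterpretCast.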
void createOffsetStore(OpBuilder& b, Location loc, Value res, Value memref,
                       Value offset) {
  Value memref_1d = createMemRef1DReinterpretCast(b, loc, memref);
  b.create<memref::StoreOp>(loc, res, memref_1d, ValueRange{offset});
}

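// Loads the element of `memref` at linear position `offset`, viewing the
// memref as 1-D via createMemRef1DReinterpretCast.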
memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref,
                                Value offset) {
  Value memref_1d = createMemRef1DReinterpretCast(b, loc, memref);
  return b.create<memref::LoadOp>(loc, memref_1d, ValueRange{offset});
}

}  // namespace lmhlo
}  // namespace mlir