1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/utils/codegen_utils.h"
17
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Pass/Pass.h"
23
24 using llvm::SmallVector;
25
26 namespace mlir {
27 namespace codegen_utils {
28
emitNumElementsComputation(OpBuilder & b,Location loc,Value memref)29 Value emitNumElementsComputation(OpBuilder& b, Location loc, Value memref) {
30 int rank = memref.getType().cast<MemRefType>().getRank();
31 Value num_elements;
32 num_elements = b.create<mlir::ConstantOp>(
33 loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1));
34 for (int r = 0; r < rank; ++r) {
35 auto dim_size = b.create<memref::DimOp>(loc, memref, r);
36 num_elements = b.create<MulIOp>(loc, num_elements, dim_size);
37 }
38 return num_elements;
39 }
40
emitNumElementsComputation(OpBuilder & b,Location loc,Operation * op)41 Value emitNumElementsComputation(OpBuilder& b, Location loc, Operation* op) {
42 // only const rank is supported for now
43 assert(op->getDialect()->getNamespace() == "lmhlo");
44 int num_operands = op->getNumOperands();
45 Value result_memref = op->getOperand(num_operands - 1);
46 return emitNumElementsComputation(b, loc, result_memref);
47 }
48
calcMultiDimIndex(OpBuilder & b,Location loc,Value linear_index,ArrayRef<Value> shape)49 SmallVector<Value, 4> calcMultiDimIndex(OpBuilder& b, Location loc,
50 Value linear_index,
51 ArrayRef<Value> shape) {
52 int rank = shape.size();
53 SmallVector<Value, 4> result;
54 if (rank == 0) return result;
55 if (rank == 1) {
56 result.push_back(linear_index);
57 return result;
58 }
59
60 // dim_acc_mul_vec = [d, c*d, b*c*d]
61 std::vector<Value> dim_acc_mul_vec;
62 Value tmp_acc_mul = shape[rank - 1];
63 dim_acc_mul_vec.emplace_back(tmp_acc_mul);
64 for (int i = rank - 2; i > 0; --i) {
65 tmp_acc_mul = b.create<MulIOp>(loc, tmp_acc_mul, shape[i]);
66 dim_acc_mul_vec.emplace_back(tmp_acc_mul);
67 }
68 Value block_index = linear_index;
69 for (int i = 0; i < rank; ++i) {
70 Value index;
71 if (i == rank - 1) {
72 index = block_index;
73 } else {
74 index =
75 b.create<UnsignedDivIOp>(loc, block_index, dim_acc_mul_vec.back());
76 block_index =
77 b.create<UnsignedRemIOp>(loc, block_index, dim_acc_mul_vec.back());
78 dim_acc_mul_vec.pop_back();
79 }
80 result.push_back(index);
81 }
82 return result;
83 }
84
calcMultiDimIndex(OpBuilder & b,Location loc,Value linear_index,Value memref)85 SmallVector<Value, 4> calcMultiDimIndex(OpBuilder& b, Location loc,
86 Value linear_index, Value memref) {
87 int rank = memref.getType().cast<MemRefType>().getRank();
88 SmallVector<Value, 4> result;
89 if (rank == 0) return result;
90 if (rank == 1) {
91 result.push_back(linear_index);
92 return result;
93 }
94 // shape = [a, b, c, d]
95 SmallVector<Value, 4> shape_vec;
96 for (int i = 0; i < rank; ++i) {
97 shape_vec.push_back(b.create<memref::DimOp>(loc, memref, i));
98 }
99
100 return calcMultiDimIndex(b, loc, linear_index, shape_vec);
101 }
102
calcMultiDimIndexForFirstOperand(OpBuilder & b,Location loc,Value linear_index,Operation * op)103 SmallVector<Value, 4> calcMultiDimIndexForFirstOperand(OpBuilder& b,
104 Location loc,
105 Value linear_index,
106 Operation* op) {
107 assert(op->getDialect()->getNamespace() == "lmhlo");
108 Value operand_memref = op->getOperand(0);
109 return calcMultiDimIndex(b, loc, linear_index, operand_memref);
110 }
111
112 } // namespace codegen_utils
113 } // namespace mlir
114