• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/utils/codegen_utils.h"
17 
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Pass/Pass.h"
23 
24 using llvm::SmallVector;
25 
26 namespace mlir {
27 namespace codegen_utils {
28 
emitNumElementsComputation(OpBuilder & b,Location loc,Value memref)29 Value emitNumElementsComputation(OpBuilder& b, Location loc, Value memref) {
30   int rank = memref.getType().cast<MemRefType>().getRank();
31   Value num_elements;
32   num_elements = b.create<mlir::ConstantOp>(
33       loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1));
34   for (int r = 0; r < rank; ++r) {
35     auto dim_size = b.create<memref::DimOp>(loc, memref, r);
36     num_elements = b.create<MulIOp>(loc, num_elements, dim_size);
37   }
38   return num_elements;
39 }
40 
emitNumElementsComputation(OpBuilder & b,Location loc,Operation * op)41 Value emitNumElementsComputation(OpBuilder& b, Location loc, Operation* op) {
42   // only const rank is supported for now
43   assert(op->getDialect()->getNamespace() == "lmhlo");
44   int num_operands = op->getNumOperands();
45   Value result_memref = op->getOperand(num_operands - 1);
46   return emitNumElementsComputation(b, loc, result_memref);
47 }
48 
calcMultiDimIndex(OpBuilder & b,Location loc,Value linear_index,ArrayRef<Value> shape)49 SmallVector<Value, 4> calcMultiDimIndex(OpBuilder& b, Location loc,
50                                         Value linear_index,
51                                         ArrayRef<Value> shape) {
52   int rank = shape.size();
53   SmallVector<Value, 4> result;
54   if (rank == 0) return result;
55   if (rank == 1) {
56     result.push_back(linear_index);
57     return result;
58   }
59 
60   // dim_acc_mul_vec = [d, c*d, b*c*d]
61   std::vector<Value> dim_acc_mul_vec;
62   Value tmp_acc_mul = shape[rank - 1];
63   dim_acc_mul_vec.emplace_back(tmp_acc_mul);
64   for (int i = rank - 2; i > 0; --i) {
65     tmp_acc_mul = b.create<MulIOp>(loc, tmp_acc_mul, shape[i]);
66     dim_acc_mul_vec.emplace_back(tmp_acc_mul);
67   }
68   Value block_index = linear_index;
69   for (int i = 0; i < rank; ++i) {
70     Value index;
71     if (i == rank - 1) {
72       index = block_index;
73     } else {
74       index =
75           b.create<UnsignedDivIOp>(loc, block_index, dim_acc_mul_vec.back());
76       block_index =
77           b.create<UnsignedRemIOp>(loc, block_index, dim_acc_mul_vec.back());
78       dim_acc_mul_vec.pop_back();
79     }
80     result.push_back(index);
81   }
82   return result;
83 }
84 
calcMultiDimIndex(OpBuilder & b,Location loc,Value linear_index,Value memref)85 SmallVector<Value, 4> calcMultiDimIndex(OpBuilder& b, Location loc,
86                                         Value linear_index, Value memref) {
87   int rank = memref.getType().cast<MemRefType>().getRank();
88   SmallVector<Value, 4> result;
89   if (rank == 0) return result;
90   if (rank == 1) {
91     result.push_back(linear_index);
92     return result;
93   }
94   // shape = [a, b, c, d]
95   SmallVector<Value, 4> shape_vec;
96   for (int i = 0; i < rank; ++i) {
97     shape_vec.push_back(b.create<memref::DimOp>(loc, memref, i));
98   }
99 
100   return calcMultiDimIndex(b, loc, linear_index, shape_vec);
101 }
102 
calcMultiDimIndexForFirstOperand(OpBuilder & b,Location loc,Value linear_index,Operation * op)103 SmallVector<Value, 4> calcMultiDimIndexForFirstOperand(OpBuilder& b,
104                                                        Location loc,
105                                                        Value linear_index,
106                                                        Operation* op) {
107   assert(op->getDialect()->getNamespace() == "lmhlo");
108   Value operand_memref = op->getOperand(0);
109   return calcMultiDimIndex(b, loc, linear_index, operand_memref);
110 }
111 
112 }  // namespace codegen_utils
113 }  // namespace mlir
114