/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file canonicalizes reduction ops in the HLO dialect to match the
// capabilities of the codegen backend.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

namespace mlir {
namespace mhlo {
namespace {

// All reduce ops can be divided into the following four types:
// - a) column reduction: only the most significant dimensions are reduced.
// - b) row reduction: only the least significant dimensions are reduced.
// - c) reduce to scalar: all dimensions are reduced.
// - d) others (not supported yet; see the example below; a transpose could
//      be used to canonicalize such cases).
//
// Currently we perform the following canonicalizations to match the
// capabilities of the codegen backend.
//
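// An illustrative example of case d): reducing only the middle dimension of
// a rank-3 tensor, e.g.
// ```
//   %2 = "mhlo.reduce"(%arg0, ...) ( {...})
//     {dimensions = dense<[1]> : tensor<1xi64>} :
//     (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// ```
// is neither a row nor a column reduction and is left untouched by this
// pass.
//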
// For case a):
// ====================================================================================
//   we convert all column reductions to rank-2 column reductions.
//   For example, suppose we have:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       ...
//       %2 = "mhlo.reduce"(%arg0, ...) ( {...})
//         {dimensions = dense<[0]> : tensor<1xi64>} :
//         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
//       return %2 : tensor<?x?xf32>
//     }
//   ```
//   After conversion:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       // [a, b, c] -> [a, b*c]
//       %1 = mhlo.dynamic_reshape(%arg0, ...) :
//         (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       %2 = "mhlo.reduce"(%1, ...) ( {...})
//         {dimensions = dense<[0]> : tensor<1xi64>} :
//         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//       %3 = "mhlo.dynamic_reshape"(%2, ...) :
//         (tensor<?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       return %3 : tensor<?x?xf32>
//     }
//   ```
//
// For case b):
// ====================================================================================
//   we convert all row reductions to rank-2 row reductions.
//   For example, suppose we have:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       ...
//       %2 = "mhlo.reduce"(%arg0, ...) ( {...})
//         {dimensions = dense<[2]> : tensor<1xi64>} :
//         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
//       return %2 : tensor<?x?xf32>
//     }
//   ```
//   After conversion:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       // [a, b, c] -> [a*b, c]
//       %1 = mhlo.dynamic_reshape(%arg0, ...) :
//         (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       %2 = "mhlo.reduce"(%1, ...) ( {...})
//         {dimensions = dense<[1]> : tensor<1xi64>} :
//         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//       %3 = "mhlo.dynamic_reshape"(%2, ...) :
//         (tensor<?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       return %3 : tensor<?x?xf32>
//     }
//   ```
//
// For case c):
// ====================================================================================
//   we convert all reduce-to-scalar ops to rank-2 column reductions.
//
//   For example, suppose we have:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
//       ...
//       %2 = "mhlo.reduce"(%arg0, ...) ( {...})
//         {dimensions = dense<[0,1,2]> : tensor<3xi64>} :
//         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
//       return %2 : tensor<f32>
//     }
//   ```
//   After conversion:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
//       // [a, b, c] -> [a*b*c, 1]
//       %1 = mhlo.dynamic_reshape(%arg0, ...) :
//         (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       %2 = "mhlo.reduce"(%1, ...) ( {...})
//         {dimensions = dense<[0]> : tensor<1xi64>} :
//         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//       %3 = "mhlo.reshape"(%2) : (tensor<?xf32>) -> tensor<f32>
//       return %3 : tensor<f32>
//     }
//   ```

struct HloCanonicalizeReductionPass
    : HloCanonicalizeReductionPassBase<HloCanonicalizeReductionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tensor::TensorDialect>();
  }
  void runOnFunction() override {
    getFunction().walk([&](ReduceOp op) {
      SmallVector<int64_t, 4> dims_to_reduce;
      DenseSet<int64_t> dims_to_reduce_set;
      for (auto dim : op.dimensions().getIntValues()) {
        dims_to_reduce.push_back(dim.getSExtValue());
        dims_to_reduce_set.insert(dims_to_reduce.back());
      }

      // An empty reduction is just a no-op; no codegen is needed.
      if (dims_to_reduce.empty()) return;

      // The input of the reduce is expected to be a ranked tensor.
      auto ty = op.getOperand(0).getType().dyn_cast<RankedTensorType>();
      if (!ty) return signalPassFailure();
      int rank = ty.getRank();
      int ndims_to_reduce = dims_to_reduce.size();
      auto elem_ty = ty.getElementType();
      llvm::sort(dims_to_reduce);

      // Skip case d) since it is not supported yet: the reduced dimensions
      // must be contiguous and anchored at either end of the shape.
      if ((dims_to_reduce.back() - dims_to_reduce[0]) !=
              (ndims_to_reduce - 1) ||
          (dims_to_reduce[0] != 0 && dims_to_reduce.back() != (rank - 1))) {
        return;
      }
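      // For instance, on a rank-3 tensor: dims [0] and [0, 1] pass as
      // column reductions, dims [2] and [1, 2] pass as row reductions,
      // while dims [1] and [0, 2] fall into case d) and are rejected here.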

      // Rank-2 row/column reductions are already supported.
      if (rank == 2 && ndims_to_reduce == 1) {
        return;
      }

      SmallVector<int64_t, 4> dims_to_keep;
      for (int i = 0; i < rank; ++i) {
        if (!dims_to_reduce_set.count(i)) dims_to_keep.push_back(i);
      }

      OpBuilder b(op);
      auto loc = op.getLoc();
      // TODO(disc): unify shape_scalar_type with shape_derivation.
      auto shape_scalar_type = b.getIntegerType(32);
      auto one = b.create<ConstantIntOp>(loc, 1ll, shape_scalar_type);

      // Helper to compute the total number of elements in the selected
      // dimensions.
      auto dim_prod = [&](ArrayRef<int64_t> dims) {
        Value nelems = one;
        for (int64_t v : dims) {
          Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), v);
          nelems = b.create<MulIOp>(
              loc, nelems,
              b.create<IndexCastOp>(loc, dim_index, shape_scalar_type));
        }
        return nelems;
      };
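      // For example, with a tensor<?x?x?xf32> operand, dim_prod({1, 2})
      // emits tensor.dim/index_cast/muli ops that compute d1 * d2 as an
      // i32 value at runtime.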

      SmallVector<Value, 2> new_operand_dims;
      DenseIntElementsAttr attr;
      Value nelem_to_reduce = dim_prod(dims_to_reduce);
      Value nelem_to_keep = dim_prod(dims_to_keep);
      if (rank == ndims_to_reduce) {
        // case c) reduce to scalar.
        // Currently we don't support reduce-to-scalar directly.
        // As a workaround, we convert the reduce-to-scalar into a rank-2
        // column reduction of the following form: suppose
        // nelems = product(shape(I)); then I is reshaped into shape
        // `[nelems, 1]`.
        // TODO(disc): this may have performance issues. Implement a
        // reduce-to-scalar schedule if necessary.
        new_operand_dims.push_back(nelem_to_reduce);
        new_operand_dims.push_back(nelem_to_keep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else if (dims_to_reduce[0] == 0) {
        // case a) column reduction
        new_operand_dims.push_back(nelem_to_reduce);
        new_operand_dims.push_back(nelem_to_keep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else {
        // case b) row reduction
        new_operand_dims.push_back(nelem_to_keep);
        new_operand_dims.push_back(nelem_to_reduce);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {1ll});
      }
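      // In all three cases the new operand is rank 2 and `attr` holds the
      // single reduced dimension: 0 for column reductions (cases a and c),
      // 1 for row reductions (case b).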

      Value new_operand_shape =
          b.create<tensor::FromElementsOp>(loc, new_operand_dims);

      SmallVector<Value, 4> new_operands;
      for (Value operand : op.inputs()) {
        new_operands.push_back(b.create<DynamicReshapeOp>(
            loc,
            RankedTensorType::get(
                SmallVector<int64_t, 4>(new_operand_dims.size(),
                                        ShapedType::kDynamicSize),
                elem_ty),
            operand, new_operand_shape));
      }
      auto new_op =
          b.create<ReduceOp>(loc, new_operands, op.init_values(), attr);
      new_op.body().takeBody(op.body());
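      // Note: takeBody moves the reduction body out of the original op
      // rather than cloning it; the original op is erased below anyway.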

      SmallVector<Value, 4> new_results;
      if (dims_to_keep.empty()) {
        // case c) reduce to scalar: reshape the rank-1 tensor of size 1
        // into a rank-0 tensor.
        for (Value result : new_op.getResults()) {
          new_results.push_back(b.create<ReshapeOp>(
              loc, RankedTensorType::get({}, elem_ty), result));
        }
      } else {
        SmallVector<Value, 4> result_dims;
        for (int64_t i : dims_to_keep) {
          Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), i);
          result_dims.push_back(
              b.create<IndexCastOp>(loc, dim_index, shape_scalar_type));
        }
        Value result_shape = b.create<tensor::FromElementsOp>(loc, result_dims);
        for (auto&& e : llvm::zip(op.getResults(), new_op.getResults())) {
          new_results.push_back(b.create<DynamicReshapeOp>(
              loc, std::get<0>(e).getType(), std::get<1>(e), result_shape));
        }
      }
      for (auto&& e : llvm::zip(op.getResults(), new_results)) {
        std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
      }
      op.erase();
    });
  }
};

}  // namespace

std::unique_ptr<FunctionPass> createHloCanonicalizeReductionPass() {
  return std::make_unique<HloCanonicalizeReductionPass>();
}

}  // namespace mhlo
}  // namespace mlir