1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27
28 namespace mlir {
29 namespace mhlo {
30
31 namespace {
32
33 // Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
34 // 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
35 // a static broadcast.
BroadcastToFeatureDim(Location loc,RankedTensorType result_type,Value value_1d,Value shape_value,int64_t feature_dim,PatternRewriter & rewriter)36 Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
37 Value value_1d, Value shape_value,
38 int64_t feature_dim,
39 PatternRewriter& rewriter) { // NOLINT
40 Builder b(rewriter.getContext());
41 auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
42 auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
43 if (shape_value) {
44 return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
45 loc, result_type, value_1d, shape_value, dims);
46 }
47 assert(result_type.hasStaticShape());
48 return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
49 dims);
50 }
51
52 // Calculate the shape value of operand, assuming it is a dynamic shape with
53 // static rank.
CalculateShapeValue(Location loc,Value operand,PatternRewriter & rewriter)54 Value CalculateShapeValue(Location loc, Value operand,
55 PatternRewriter& rewriter) { // NOLINT
56 RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
57 llvm::SmallVector<Value, 4> shape_values;
58 int64_t rank = result_type.getRank();
59 shape_values.reserve(rank);
60 for (int64_t i = 0; i < rank; ++i) {
61 shape_values.push_back(
62 rewriter.create<mlir::tensor::DimOp>(loc, operand, i));
63 }
64 return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
65 }
66
// Materializes the batch-norm epsilon constant as a tensor broadcast to
// 'broadcast_to_type' so it can be added elementwise to the variance.
// Converts 'epsilon_attr' to 'fp_type' first if the types differ; returns
// nullptr (after emitting a warning on 'op') when that conversion fails.
// 'variance' is only used to compute a runtime shape for the dynamic
// broadcast case.
Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
                         FloatType fp_type, Value variance,
                         RankedTensorType broadcast_to_type,
                         PatternRewriter& rewriter) {  // NOLINT
  Builder b(rewriter.getContext());
  if (epsilon_attr.getType() != fp_type) {
    // Need to convert.
    bool loses_info;
    APFloat epsilon_float = epsilon_attr.getValue();
    auto status = epsilon_float.convert(
        fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info);
    // Any conversion error other than inexactness (e.g. overflow to a
    // narrower float type) aborts the materialization.
    if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
      op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
                           "type: opStatus = "
                        << static_cast<int>(status);
      return nullptr;
    }
    if (loses_info) {
      op->emitWarning("Conversion of epsilon loses precision");
    }
    epsilon_attr = b.getFloatAttr(fp_type, epsilon_float);
  }

  // Build a rank-0 (scalar) constant holding epsilon.
  auto scalar_type = RankedTensorType::get({}, fp_type);
  auto epsilon_tensor_attr =
      DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
  Value epsilon =
      rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
  // A scalar broadcast uses an empty broadcast_dimensions attribute.
  auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
  if (broadcast_to_type.hasStaticShape()) {
    return rewriter.create<mhlo::BroadcastInDimOp>(
        op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
  }
  // Dynamic shape: derive the target shape from the variance operand at
  // runtime and use the dynamic broadcast form.
  Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
  return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
      op->getLoc(), broadcast_to_type, epsilon, shape_value,
      /*broadcast_dims=*/dims);
}
106
// Rewrites mhlo.batch_norm_inference into its constituent arithmetic:
//   result = scale * (input - mean) / sqrt(variance + epsilon) + offset
// with all 1D per-feature operands broadcast along the feature dimension.
class UnfuseBatchNormInferencePattern
    : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
 public:
  using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
                                PatternRewriter& rewriter) const override {
    // Enforce type invariants.
    // Note that we deduce the actual element type from the variance,
    // which should not be subject to quantization at a higher level.
    auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
    auto variance_type =
        bn_op.variance().getType().dyn_cast<RankedTensorType>();
    if (!input_type || !variance_type) {
      // Unranked operands are not supported; leave the op fused.
      return failure();
    }
    auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
    if (!fp_type) {
      return failure();
    }
    int64_t feature_dim = bn_op.feature_index();

    // Add epsilon to the variance and sqrt to get stddev:
    // stddev = sqrt(variance + epsilon)
    auto epsilon =
        MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
                           bn_op.variance(), variance_type, rewriter);
    if (!epsilon) {
      // Epsilon could not be converted to the variance element type.
      return failure();
    }
    Value stddev =
        rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
    stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);

    // Broadcast all terms.
    // For a dynamically shaped input, compute its runtime shape once and
    // reuse it for all four broadcasts; otherwise leave shape_value null so
    // static broadcasts are emitted.
    Value shape_value;
    if (!input_type.hasStaticShape()) {
      shape_value =
          CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
    }
    auto broadcast_scale =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_offset =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_mean =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_stddev = BroadcastToFeatureDim(
        bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);

    // Compute:
    // scale * (input - mean) / stddev + offset
    Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
                                                broadcast_mean);
    result =
        rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
    result =
        rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
    // Replace the fused op with the final add so all uses are rewired.
    rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);

    return success();
  }
};
172
173 } // namespace
174
175 // Populates conversion patterns to unfuse batch normalization operations.
176 // In combination with marking such ops as illegal, this allows backends that
177 // do not have special support for fused batchnorm to use simpler arithmetic
178 // primitives.
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
                                     OwningRewritePatternList* patterns) {
  // Currently only the inference form is unfused; training ops are untouched.
  patterns->insert<UnfuseBatchNormInferencePattern>(context);
}
183
184 } // namespace mhlo
185 } // namespace mlir
186