• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <cassert>

#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"

28 namespace mlir {
29 namespace mhlo {
30 
31 namespace {
32 
33 // Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
34 // 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
35 // a static broadcast.
BroadcastToFeatureDim(Location loc,RankedTensorType result_type,Value value_1d,Value shape_value,int64_t feature_dim,PatternRewriter & rewriter)36 Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
37                             Value value_1d, Value shape_value,
38                             int64_t feature_dim,
39                             PatternRewriter& rewriter) {  // NOLINT
40   Builder b(rewriter.getContext());
41   auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
42   auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
43   if (shape_value) {
44     return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
45         loc, result_type, value_1d, shape_value, dims);
46   }
47   assert(result_type.hasStaticShape());
48   return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
49                                                  dims);
50 }
51 
52 // Calculate the shape value of operand, assuming it is a dynamic shape with
53 // static rank.
CalculateShapeValue(Location loc,Value operand,PatternRewriter & rewriter)54 Value CalculateShapeValue(Location loc, Value operand,
55                           PatternRewriter& rewriter) {  // NOLINT
56   RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
57   llvm::SmallVector<Value, 4> shape_values;
58   int64_t rank = result_type.getRank();
59   shape_values.reserve(rank);
60   for (int64_t i = 0; i < rank; ++i) {
61     shape_values.push_back(
62         rewriter.create<mlir::tensor::DimOp>(loc, operand, i));
63   }
64   return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
65 }
66 
MaterializeEpsilon(Operation * op,FloatAttr epsilon_attr,FloatType fp_type,Value variance,RankedTensorType broadcast_to_type,PatternRewriter & rewriter)67 Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
68                          FloatType fp_type, Value variance,
69                          RankedTensorType broadcast_to_type,
70                          PatternRewriter& rewriter) {  // NOLINT
71   Builder b(rewriter.getContext());
72   if (epsilon_attr.getType() != fp_type) {
73     // Need to convert.
74     bool loses_info;
75     APFloat epsilon_float = epsilon_attr.getValue();
76     auto status = epsilon_float.convert(
77         fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info);
78     if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
79       op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
80                            "type: opStatus = "
81                         << static_cast<int>(status);
82       return nullptr;
83     }
84     if (loses_info) {
85       op->emitWarning("Conversion of epsilon loses precision");
86     }
87     epsilon_attr = b.getFloatAttr(fp_type, epsilon_float);
88   }
89 
90   auto scalar_type = RankedTensorType::get({}, fp_type);
91   auto epsilon_tensor_attr =
92       DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
93   Value epsilon =
94       rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
95   auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
96   auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
97   if (broadcast_to_type.hasStaticShape()) {
98     return rewriter.create<mhlo::BroadcastInDimOp>(
99         op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
100   }
101   Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
102   return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
103       op->getLoc(), broadcast_to_type, epsilon, shape_value,
104       /*broadcast_dims=*/dims);
105 }
106 
107 class UnfuseBatchNormInferencePattern
108     : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
109  public:
110   using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
111 
matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,PatternRewriter & rewriter) const112   LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
113                                 PatternRewriter& rewriter) const override {
114     // Enforce type invariants.
115     // Note that we deduce the actual element type from the variance,
116     // which should not be subject to quantization at a higher level.
117     auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
118     auto variance_type =
119         bn_op.variance().getType().dyn_cast<RankedTensorType>();
120     if (!input_type || !variance_type) {
121       return failure();
122     }
123     auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
124     if (!fp_type) {
125       return failure();
126     }
127     int64_t feature_dim = bn_op.feature_index();
128 
129     // Add epsilon to the variance and sqrt to get stddev:
130     // stddev = sqrt(variance + epsilon)
131     auto epsilon =
132         MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
133                            bn_op.variance(), variance_type, rewriter);
134     if (!epsilon) {
135       return failure();
136     }
137     Value stddev =
138         rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
139     stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);
140 
141     // Broadcast all terms.
142     Value shape_value;
143     if (!input_type.hasStaticShape()) {
144       shape_value =
145           CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
146     }
147     auto broadcast_scale =
148         BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
149                               shape_value, feature_dim, rewriter);
150     auto broadcast_offset =
151         BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
152                               shape_value, feature_dim, rewriter);
153     auto broadcast_mean =
154         BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
155                               shape_value, feature_dim, rewriter);
156     auto broadcast_stddev = BroadcastToFeatureDim(
157         bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);
158 
159     // Compute:
160     // scale * (input - mean) / stddev + offset
161     Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
162                                                 broadcast_mean);
163     result =
164         rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
165     result =
166         rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
167     rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);
168 
169     return success();
170   }
171 };
172 
173 }  // namespace
174 
// Populates conversion patterns to unfuse batch normalization operations.
// In combination with marking such ops as illegal, this allows backends that
// do not have special support for fused batchnorm to use simpler arithmetic
// primitives.
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
                                     OwningRewritePatternList* patterns) {
  // Only the inference form of batch norm is unfused here.
  patterns->insert<UnfuseBatchNormInferencePattern>(context);
}
183 
184 }  // namespace mhlo
185 }  // namespace mlir
186