/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/coreml/builders/pooling_layer_builder.h"

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegates {
namespace coreml {

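// Returns a human-readable name for this builder based on the pooling type.
// Build() also uses this string as the Core ML layer name.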
const std::string& PoolingLayerBuilder::DebugName() {
  if (!debug_name_.empty()) return debug_name_;
  switch (pooling_type_) {
    case kTfLiteBuiltinAveragePool2d:
      SetDebugName("PoolingLayerBuilder (AVERAGE)", node_id_);
      break;
    case kTfLiteBuiltinMaxPool2d:
      SetDebugName("PoolingLayerBuilder (MAX)", node_id_);
      break;
    case kTfLiteBuiltinL2Pool2d:
      SetDebugName("PoolingLayerBuilder (L2, unsupported)", node_id_);
      break;
    case kTfLiteBuiltinMean:
      SetDebugName("PoolingLayerBuilder (MEAN)", node_id_);
      break;
    default:
      SetDebugName("PoolingLayerBuilder (ERROR)", node_id_);
  }
  return debug_name_;
}

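// Translates the TFLite pooling op into a Core ML PoolingLayerParams message.
// Mean is lowered to a global average pool; average and max pooling copy the
// stride, kernel size, and padding from the TFLite builtin params. L2 pooling
// is not supported yet and returns nullptr.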
CoreML::Specification::NeuralNetworkLayer* PoolingLayerBuilder::Build() {
  layer_->set_name(DebugName());
  auto* pooling_params = layer_->mutable_pooling();

  if (pooling_type_ == kTfLiteBuiltinMean) {
    pooling_params->set_type(
        CoreML::Specification::PoolingLayerParams::AVERAGE);
    pooling_params->set_globalpooling(true);
    return layer_.release();
  }

  const TfLitePoolParams* params =
      reinterpret_cast<const TfLitePoolParams*>(builtin_data_);
  pooling_params->mutable_stride()->Add(params->stride_height);
  pooling_params->mutable_stride()->Add(params->stride_width);
  pooling_params->mutable_kernelsize()->Add(params->filter_height);
  pooling_params->mutable_kernelsize()->Add(params->filter_width);

  if (params->padding == kTfLitePaddingSame) {
    pooling_params->mutable_same();
  } else {
    pooling_params->mutable_valid();
  }

  switch (pooling_type_) {
    case kTfLiteBuiltinAveragePool2d:
      pooling_params->set_type(
          CoreML::Specification::PoolingLayerParams::AVERAGE);
      pooling_params->set_avgpoolexcludepadding(true);
      break;
    case kTfLiteBuiltinMaxPool2d:
      pooling_params->set_type(CoreML::Specification::PoolingLayerParams::MAX);
      break;
    case kTfLiteBuiltinL2Pool2d:
      // TODO(b/145873272) implement L2 pooling
      // NOLINTNEXTLINE: minimize absl usage
      fprintf(stderr, "L2 pooling is not supported yet.\n");
      return nullptr;
    default:
      // NOLINTNEXTLINE: minimize absl usage
      fprintf(stderr, "Unexpected pooling type.\n");  // Should not reach here.
      return nullptr;
  }

  // TODO(b/145582958): Add padding values.
  // TODO(b/145582958): Handle fused activation function.
  return layer_.release();
}

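// Validates the input count (Mean carries an extra axis tensor) and wires only
// the data tensor (inputs->data[0]) into the Core ML graph; the axis tensor is
// validated separately in IsMeanOpSupported below.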
TfLiteStatus PoolingLayerBuilder::RegisterInputs(const TfLiteIntArray* inputs,
                                                 TfLiteContext* context) {
  if (pooling_type_ == kTfLiteBuiltinMean) {
    if (inputs->size != 2) {
      TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Mean!");
      return kTfLiteError;
    }
  } else if (inputs->size != 1) {
    TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Pooling!");
    return kTfLiteError;
  }
  AddInput(inputs->data[0]);
  return kTfLiteOk;
}

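// Registers the single output tensor, mapping its TFLite tensor index to this
// layer's output in the graph builder.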
TfLiteStatus PoolingLayerBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
                                                  TfLiteContext* context) {
  if (outputs->size != 1) {
    TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Pooling!");
    return kTfLiteError;
  }
  graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
  return kTfLiteOk;
}

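// Factory functions that create a PoolingLayerBuilder for each supported
// TFLite builtin op.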
OpBuilder* CreateAveragePool2dOpBuilder(GraphBuilder* graph_builder) {
  return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinAveragePool2d);
}

OpBuilder* CreateMaxPool2dOpBuilder(GraphBuilder* graph_builder) {
  return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMaxPool2d);
}

OpBuilder* CreateMeanOpBuilder(GraphBuilder* graph_builder) {
  return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMean);
}

// Only supports averaging over the H and W dimensions.
bool IsMeanOpSupported(const TfLiteRegistration* registration,
                       const TfLiteNode* node, TfLiteContext* context) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  const TfLiteTensor* axis = GetInput(context, node, 1);
  const auto* params =
      reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);

  if (!params->keep_dims) {
    TF_LITE_KERNEL_LOG(context, "keep_dims should be true for Mean op.");
    return false;
  }
  if (input->dims->size != 4) {
    TF_LITE_KERNEL_LOG(context, "Mean op is only supported for 4D input.");
    return false;
  }
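  // axis_mask marks which dimensions of the NHWC input may be reduced: only
  // H (index 1) and W (index 2). The (axis + 4) % 4 below maps negative axis
  // values to their positive equivalents.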
  const int* axis_data = GetTensorData<int>(axis);
  std::vector<bool> axis_mask = {false, true, true, false};
  for (int i = 0; i < axis->dims->data[0]; ++i) {
    if (!axis_mask[(axis_data[i] + 4) % 4]) {
      TF_LITE_KERNEL_LOG(context,
                         "Mean op should only reduce over the H and W dimensions.");
      return false;
    }
  }
  return true;
}

}  // namespace coreml
}  // namespace delegates
}  // namespace tflite