1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/pooling_layer_builder.h"
16
17 #include "tensorflow/lite/builtin_ops.h"
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace delegates {
26 namespace coreml {
27
DebugName()28 const std::string& PoolingLayerBuilder::DebugName() {
29 if (!debug_name_.empty()) return debug_name_;
30 switch (pooling_type_) {
31 case kTfLiteBuiltinAveragePool2d:
32 SetDebugName("PoolingLayerBuilder (AVERAGE)", node_id_);
33 break;
34 case kTfLiteBuiltinMaxPool2d:
35 SetDebugName("PoolingLayerBuilder (MAX)", node_id_);
36 break;
37 case kTfLiteBuiltinL2Pool2d:
38 SetDebugName("PoolingLayerBuilder (L2, unsupported)", node_id_);
39 break;
40 case kTfLiteBuiltinMean:
41 SetDebugName("PoolingLayerBuilder (MEAN)", node_id_);
42 break;
43 default:
44 SetDebugName("PoolingLayerBuilder (ERROR)", node_id_);
45 }
46 return debug_name_;
47 }
48
Build()49 CoreML::Specification::NeuralNetworkLayer* PoolingLayerBuilder::Build() {
50 layer_->set_name(DebugName());
51 auto* pooling_params = layer_->mutable_pooling();
52
53 if (pooling_type_ == kTfLiteBuiltinMean) {
54 pooling_params->set_type(
55 CoreML::Specification::PoolingLayerParams::AVERAGE);
56 pooling_params->set_globalpooling(true);
57 return layer_.release();
58 }
59
60 const TfLitePoolParams* params =
61 reinterpret_cast<const TfLitePoolParams*>(builtin_data_);
62 pooling_params->mutable_stride()->Add(params->stride_height);
63 pooling_params->mutable_stride()->Add(params->stride_width);
64 pooling_params->mutable_kernelsize()->Add(params->filter_height);
65 pooling_params->mutable_kernelsize()->Add(params->filter_width);
66
67 if (params->padding == kTfLitePaddingSame) {
68 pooling_params->mutable_same();
69 } else {
70 pooling_params->mutable_valid();
71 }
72
73 switch (pooling_type_) {
74 case kTfLiteBuiltinAveragePool2d:
75 pooling_params->set_type(
76 CoreML::Specification::PoolingLayerParams::AVERAGE);
77 pooling_params->set_avgpoolexcludepadding(true);
78 break;
79 case kTfLiteBuiltinMaxPool2d:
80 pooling_params->set_type(CoreML::Specification::PoolingLayerParams::MAX);
81 break;
82 case kTfLiteBuiltinL2Pool2d:
83 // TODO(b/145873272) implement L2 pooling
84 // NOLINTNEXTLINE: minimize absl usage
85 fprintf(stderr, "L2 pooling is not supported yet.\n");
86 return nullptr;
87 default:
88 // NOLINTNEXTLINE: minimize absl usage
89 fprintf(stderr, "Unexpected pooling type.\n"); // Should not reach here.
90 return nullptr;
91 }
92
93 // TODO(b/145582958): Add padding values.
94 // TODO(b/145582958): Handle fused activation function.
95 return layer_.release();
96 }
97
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)98 TfLiteStatus PoolingLayerBuilder::RegisterInputs(const TfLiteIntArray* inputs,
99 TfLiteContext* context) {
100 if (pooling_type_ == kTfLiteBuiltinMean) {
101 if (inputs->size != 2) {
102 TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Mean!.");
103 return kTfLiteError;
104 }
105 } else if (inputs->size != 1) {
106 TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Pooling!.");
107 return kTfLiteError;
108 }
109 AddInput(inputs->data[0]);
110 return kTfLiteOk;
111 }
112
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)113 TfLiteStatus PoolingLayerBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
114 TfLiteContext* context) {
115 if (outputs->size != 1) {
116 TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Pooling!.");
117 return kTfLiteError;
118 }
119 graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
120 return kTfLiteOk;
121 }
122
CreateAveragePool2dOpBuilder(GraphBuilder * graph_builder)123 OpBuilder* CreateAveragePool2dOpBuilder(GraphBuilder* graph_builder) {
124 return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinAveragePool2d);
125 }
126
CreateMaxPool2dOpBuilder(GraphBuilder * graph_builder)127 OpBuilder* CreateMaxPool2dOpBuilder(GraphBuilder* graph_builder) {
128 return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMaxPool2d);
129 }
130
CreateMeanOpBuilder(GraphBuilder * graph_builder)131 OpBuilder* CreateMeanOpBuilder(GraphBuilder* graph_builder) {
132 return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMean);
133 }
134
135 // Only supports averaging over H and W dimensions, as
IsMeanOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)136 bool IsMeanOpSupported(const TfLiteRegistration* registration,
137 const TfLiteNode* node, TfLiteContext* context) {
138 const TfLiteTensor* input = GetInput(context, node, 0);
139 const TfLiteTensor* axis = GetInput(context, node, 1);
140 const auto* params =
141 reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
142
143 if (!params->keep_dims) {
144 TF_LITE_KERNEL_LOG(context, "keep_dims should be true for Mean op.");
145 return false;
146 }
147 if (input->dims->size != 4) {
148 TF_LITE_KERNEL_LOG(context, "Mean op is only supported for 4D input.");
149 return false;
150 }
151 const int* axis_data = GetTensorData<int>(axis);
152 std::vector<bool> axis_mask = {false, true, true, false};
153 for (int i = 0; i < axis->dims->data[0]; ++i) {
154 if (!axis_mask[(axis_data[i] + 4) % 4]) {
155 TF_LITE_KERNEL_LOG(context,
156 "Mean op should reduce for H and W dimensions.");
157 return false;
158 }
159 }
160 return true;
161 }
162
163 } // namespace coreml
164 } // namespace delegates
165 } // namespace tflite
166