/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/fully_connected_op_builder.h"
16
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/delegates/coreml/builders/activation_layer_builder.h"
20 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace delegates {
26 namespace coreml {
DebugName()27 const std::string& FullyConnectedOpBuilder::DebugName() {
28 if (debug_name_.empty()) SetDebugName("FullyConnectedOpBuilder", node_id_);
29 return debug_name_;
30 }
31
SetWeights(TfLiteTensor * weights)32 void FullyConnectedOpBuilder::SetWeights(TfLiteTensor* weights) {
33 weights_ = weights;
34 }
35
SetBias(TfLiteTensor * bias)36 void FullyConnectedOpBuilder::SetBias(TfLiteTensor* bias) { bias_ = bias; }
37
Build()38 CoreML::Specification::NeuralNetworkLayer* FullyConnectedOpBuilder::Build() {
39 if (layer_ == nullptr) {
40 layer_.reset(new CoreML::Specification::NeuralNetworkLayer);
41 }
42 layer_->set_name(DebugName());
43
44 FillCoreMLWeights();
45 FillCoreMLBias();
46
47 return layer_.release();
48 }
49
FillCoreMLWeights()50 void FullyConnectedOpBuilder::FillCoreMLWeights() {
51 layer_->mutable_innerproduct()->set_inputchannels(weights_->dims->data[1]);
52 layer_->mutable_innerproduct()->set_outputchannels(weights_->dims->data[0]);
53 if (weights_->type == kTfLiteFloat32) {
54 const float* weights_data = GetTensorData<float>(weights_);
55 std::copy(weights_data, weights_data + NumElements(weights_),
56 google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
57 ->mutable_weights()
58 ->mutable_floatvalue()));
59 } else if (weights_->type == kTfLiteFloat16) {
60 // float16value has type of bytes (std::string)
61 layer_->mutable_innerproduct()
62 ->mutable_weights()
63 ->mutable_float16value()
64 ->assign(weights_->data.raw, weights_->bytes);
65 }
66 }
67
FillCoreMLBias()68 void FullyConnectedOpBuilder::FillCoreMLBias() {
69 if (bias_ != nullptr) {
70 layer_->mutable_innerproduct()->set_hasbias(true);
71 if (bias_->type == kTfLiteFloat32) {
72 const float* bias_data = GetTensorData<float>(bias_);
73 std::copy(bias_data, bias_data + NumElements(bias_),
74 google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
75 ->mutable_bias()
76 ->mutable_floatvalue()));
77 } else if (bias_->type == kTfLiteFloat16) {
78 // float16value has type of bytes (std::string)
79 layer_->mutable_innerproduct()
80 ->mutable_bias()
81 ->mutable_float16value()
82 ->assign(bias_->data.raw, bias_->bytes);
83 }
84 }
85 }
86
PopulateSubgraph(TfLiteContext * context)87 TfLiteStatus FullyConnectedOpBuilder::PopulateSubgraph(TfLiteContext* context) {
88 const auto* fc_params =
89 reinterpret_cast<const TfLiteFullyConnectedParams*>(builtin_data_);
90 TfLiteFusedActivation activation = fc_params->activation;
91
92 if (activation == kTfLiteActNone) {
93 builder_output_ = AddOutput();
94 } else {
95 ActivationLayerBuilder* activation_builder =
96 reinterpret_cast<ActivationLayerBuilder*>(
97 graph_builder_->AddBuilder(CreateActivationLayerBuilder, nullptr));
98 activation_builder->SetActivation(activation);
99 activation_builder->AddInput(AddOutput());
100 activation_builder->PopulateSubgraph(context);
101 builder_output_ = activation_builder->GetOutput(context);
102 }
103 return kTfLiteOk;
104 }
105
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)106 TfLiteStatus FullyConnectedOpBuilder::RegisterInputs(
107 const TfLiteIntArray* inputs, TfLiteContext* context) {
108 const int kInput = 0;
109 const int kWeights = 1;
110 const int kBias = 2;
111 AddInput(inputs->data[kInput]);
112 SetWeights(&context->tensors[inputs->data[kWeights]]);
113 if (inputs->size > 2) {
114 SetBias(&context->tensors[inputs->data[kBias]]);
115 }
116 return kTfLiteOk;
117 }
118
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)119 TfLiteStatus FullyConnectedOpBuilder::RegisterOutputs(
120 const TfLiteIntArray* outputs, TfLiteContext* context) {
121 if (outputs->size != 1) {
122 TF_LITE_KERNEL_LOG(context, "Wrong # of outputs!.");
123 return kTfLiteError;
124 }
125 TensorID output_tensor = GetOutput(context);
126 if (output_tensor.NodeID() == -1) {
127 TF_LITE_KERNEL_LOG(context, "Failed to build output tensor.");
128 return kTfLiteError;
129 }
130 graph_builder_->AddTensorWithID(outputs->data[0], output_tensor);
131 return kTfLiteOk;
132 }
133
CreateFullyConnectedOpBuilder(GraphBuilder * graph_builder)134 OpBuilder* CreateFullyConnectedOpBuilder(GraphBuilder* graph_builder) {
135 return new FullyConnectedOpBuilder(graph_builder);
136 }
137
IsFloatType(TfLiteType type)138 bool IsFloatType(TfLiteType type) {
139 return type == kTfLiteFloat32 || type == kTfLiteFloat16;
140 }
141
IsFullyConnectedOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)142 bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
143 const TfLiteNode* node,
144 TfLiteContext* context) {
145 if (node->builtin_data == nullptr) return false;
146 const auto* fc_params =
147 reinterpret_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
148 const int kInput = 0;
149 const int kWeights = 1;
150 const int kBias = 2;
151
152 if (fc_params->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) {
153 return false;
154 }
155 const TfLiteTensor* input;
156 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
157 const TfLiteTensor* weights;
158 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeights, &weights));
159
160 if (!IsFloatType(input->type)) {
161 return false;
162 }
163 if (!IsFloatType(weights->type) || !IsConstantTensor(weights)) {
164 return false;
165 }
166 // Core ML 2 only supports single-batch fully connected layer, thus dimensions
167 // except the last one should be 1.
168 if (input->dims->data[input->dims->size - 1] != NumElements(input)) {
169 return false;
170 }
171
172 if (node->inputs->size > 2) {
173 const TfLiteTensor* bias;
174 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBias, &bias));
175 if (!IsFloatType(bias->type) || !IsConstantTensor(bias)) {
176 return false;
177 }
178 }
179
180 TfLiteFusedActivation activation = fc_params->activation;
181 if (activation == kTfLiteActSignBit) {
182 return false;
183 }
184 return true;
185 }
186
187 } // namespace coreml
188 } // namespace delegates
189 } // namespace tflite
190