• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/fully_connected_op_builder.h"
16 
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/delegates/coreml/builders/activation_layer_builder.h"
20 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace delegates {
26 namespace coreml {
DebugName()27 const std::string& FullyConnectedOpBuilder::DebugName() {
28   if (debug_name_.empty()) SetDebugName("FullyConnectedOpBuilder", node_id_);
29   return debug_name_;
30 }
31 
SetWeights(TfLiteTensor * weights)32 void FullyConnectedOpBuilder::SetWeights(TfLiteTensor* weights) {
33   weights_ = weights;
34 }
35 
SetBias(TfLiteTensor * bias)36 void FullyConnectedOpBuilder::SetBias(TfLiteTensor* bias) { bias_ = bias; }
37 
Build()38 CoreML::Specification::NeuralNetworkLayer* FullyConnectedOpBuilder::Build() {
39   if (layer_ == nullptr) {
40     layer_.reset(new CoreML::Specification::NeuralNetworkLayer);
41   }
42   layer_->set_name(DebugName());
43 
44   FillCoreMLWeights();
45   FillCoreMLBias();
46 
47   return layer_.release();
48 }
49 
FillCoreMLWeights()50 void FullyConnectedOpBuilder::FillCoreMLWeights() {
51   layer_->mutable_innerproduct()->set_inputchannels(weights_->dims->data[1]);
52   layer_->mutable_innerproduct()->set_outputchannels(weights_->dims->data[0]);
53   if (weights_->type == kTfLiteFloat32) {
54     const float* weights_data = GetTensorData<float>(weights_);
55     std::copy(weights_data, weights_data + NumElements(weights_),
56               google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
57                                                     ->mutable_weights()
58                                                     ->mutable_floatvalue()));
59   } else if (weights_->type == kTfLiteFloat16) {
60     // float16value has type of bytes (std::string)
61     layer_->mutable_innerproduct()
62         ->mutable_weights()
63         ->mutable_float16value()
64         ->assign(weights_->data.raw, weights_->bytes);
65   }
66 }
67 
FillCoreMLBias()68 void FullyConnectedOpBuilder::FillCoreMLBias() {
69   if (bias_ != nullptr) {
70     layer_->mutable_innerproduct()->set_hasbias(true);
71     if (bias_->type == kTfLiteFloat32) {
72       const float* bias_data = GetTensorData<float>(bias_);
73       std::copy(bias_data, bias_data + NumElements(bias_),
74                 google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct()
75                                                       ->mutable_bias()
76                                                       ->mutable_floatvalue()));
77     } else if (bias_->type == kTfLiteFloat16) {
78       // float16value has type of bytes (std::string)
79       layer_->mutable_innerproduct()
80           ->mutable_bias()
81           ->mutable_float16value()
82           ->assign(bias_->data.raw, bias_->bytes);
83     }
84   }
85 }
86 
PopulateSubgraph(TfLiteContext * context)87 TfLiteStatus FullyConnectedOpBuilder::PopulateSubgraph(TfLiteContext* context) {
88   const auto* fc_params =
89       reinterpret_cast<const TfLiteFullyConnectedParams*>(builtin_data_);
90   TfLiteFusedActivation activation = fc_params->activation;
91 
92   if (activation == kTfLiteActNone) {
93     builder_output_ = AddOutput();
94   } else {
95     ActivationLayerBuilder* activation_builder =
96         reinterpret_cast<ActivationLayerBuilder*>(
97             graph_builder_->AddBuilder(CreateActivationLayerBuilder, nullptr));
98     activation_builder->SetActivation(activation);
99     activation_builder->AddInput(AddOutput());
100     activation_builder->PopulateSubgraph(context);
101     builder_output_ = activation_builder->GetOutput(context);
102   }
103   return kTfLiteOk;
104 }
105 
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)106 TfLiteStatus FullyConnectedOpBuilder::RegisterInputs(
107     const TfLiteIntArray* inputs, TfLiteContext* context) {
108   const int kInput = 0;
109   const int kWeights = 1;
110   const int kBias = 2;
111   AddInput(inputs->data[kInput]);
112   SetWeights(&context->tensors[inputs->data[kWeights]]);
113   if (inputs->size > 2) {
114     SetBias(&context->tensors[inputs->data[kBias]]);
115   }
116   return kTfLiteOk;
117 }
118 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)119 TfLiteStatus FullyConnectedOpBuilder::RegisterOutputs(
120     const TfLiteIntArray* outputs, TfLiteContext* context) {
121   if (outputs->size != 1) {
122     TF_LITE_KERNEL_LOG(context, "Wrong # of outputs!.");
123     return kTfLiteError;
124   }
125   TensorID output_tensor = GetOutput(context);
126   if (output_tensor.NodeID() == -1) {
127     TF_LITE_KERNEL_LOG(context, "Failed to build output tensor.");
128     return kTfLiteError;
129   }
130   graph_builder_->AddTensorWithID(outputs->data[0], output_tensor);
131   return kTfLiteOk;
132 }
133 
CreateFullyConnectedOpBuilder(GraphBuilder * graph_builder)134 OpBuilder* CreateFullyConnectedOpBuilder(GraphBuilder* graph_builder) {
135   return new FullyConnectedOpBuilder(graph_builder);
136 }
137 
IsFloatType(TfLiteType type)138 bool IsFloatType(TfLiteType type) {
139   return type == kTfLiteFloat32 || type == kTfLiteFloat16;
140 }
141 
IsFullyConnectedOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)142 bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
143                                  const TfLiteNode* node,
144                                  TfLiteContext* context) {
145   if (node->builtin_data == nullptr) return false;
146   const auto* fc_params =
147       reinterpret_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
148   const int kInput = 0;
149   const int kWeights = 1;
150   const int kBias = 2;
151 
152   if (fc_params->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) {
153     return false;
154   }
155   const TfLiteTensor* input;
156   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
157   const TfLiteTensor* weights;
158   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeights, &weights));
159 
160   if (!IsFloatType(input->type)) {
161     return false;
162   }
163   if (!IsFloatType(weights->type) || !IsConstantTensor(weights)) {
164     return false;
165   }
166   // Core ML 2 only supports single-batch fully connected layer, thus dimensions
167   // except the last one should be 1.
168   if (input->dims->data[input->dims->size - 1] != NumElements(input)) {
169     return false;
170   }
171 
172   if (node->inputs->size > 2) {
173     const TfLiteTensor* bias;
174     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBias, &bias));
175     if (!IsFloatType(bias->type) || !IsConstantTensor(bias)) {
176       return false;
177     }
178   }
179 
180   TfLiteFusedActivation activation = fc_params->activation;
181   if (activation == kTfLiteActSignBit) {
182     return false;
183   }
184   return true;
185 }
186 
187 }  // namespace coreml
188 }  // namespace delegates
189 }  // namespace tflite
190