1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/lite/delegates/coreml/coreml_delegate.h" 16 17#include <string.h> 18#include <sys/utsname.h> 19#include <limits> 20#include <vector> 21 22#include "tensorflow/lite/builtin_ops.h" 23#include "tensorflow/lite/c/builtin_op_data.h" 24#include "tensorflow/lite/context_util.h" 25#include "tensorflow/lite/delegates/coreml/builders/op_validator.h" 26#include "tensorflow/lite/delegates/coreml/builders/util.h" 27#include "tensorflow/lite/delegates/coreml/coreml_delegate_kernel.h" 28#include "tensorflow/lite/delegates/utils.h" 29#include "tensorflow/lite/kernels/kernel_util.h" 30#include "tensorflow/lite/minimal_logging.h" 31 32namespace tflite { 33namespace { 34constexpr int kMinNodesPerCoreMlDelegate = 2; 35 36using delegates::coreml::CoreMlDelegateKernel; 37 38bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration, const TfLiteNode* node, 39 TfLiteContext* context, const TfLiteCoreMlDelegateOptions* options) { 40 if (@available(iOS 11.0, *)) { 41 } else { 42 return false; 43 } 44 45 // For most ops, only version 1 is supported. 46 if (registration->version > 1) { 47 switch (registration->builtin_code) { 48 case kTfLiteBuiltinDepthwiseConv2d: 49 if (registration->version > 2) return false; 50 break; 51 // FullyConnected without bias is supported starting from version 6. 52 case kTfLiteBuiltinFullyConnected: 53 if (registration->version > 6) return false; 54 break; 55 default: 56 return false; 57 } 58 } 59 60 // The model should not be full-integer quantized. For ops supported by Core ML delegate, 61 // Testing if the first input is float is sufficient to filter full-integer quantized ops. 62 int input_tensor_index = 0; 63 // TransposeConv input: (output_shape, filters, input) 64 if (registration->builtin_code == kTfLiteBuiltinTransposeConv) { 65 input_tensor_index = 2; 66 } 67 if (GetInput(context, node, input_tensor_index)->type != kTfLiteFloat32) { 68 return false; 69 } 70 71 // TODO(b/149179044): Add extra validation if this is not sufficient. 72 73 // TODO(karimnossier): Refactor this function. 74 // TODO(karimnosseir): Add 75 // 1) Checks for versioning. 76 // 2) Checks for input constraints. 77 // Follow the ordering of TfLiteBuiltinOperator enum. 78 switch (registration->builtin_code) { 79 case kTfLiteBuiltinAdd: { 80 return node->builtin_data != nullptr && 81 delegates::coreml::IsBinaryOpSupported(registration, node, context); 82 } 83 case kTfLiteBuiltinAveragePool2d: { 84 const auto* params = reinterpret_cast<const TfLitePoolParams*>(node->builtin_data); 85 return params != nullptr && params->activation == kTfLiteActNone; 86 } 87 case kTfLiteBuiltinConcatenation: { 88 return delegates::coreml::IsConcatenationOpSupported(registration, node, context); 89 } 90 case kTfLiteBuiltinConv2d: { 91 return delegates::coreml::IsConvolutionOpSupported(registration, node, context); 92 } 93 case kTfLiteBuiltinDepthwiseConv2d: { 94 return delegates::coreml::IsDepthwiseConvolutionOpSupported(registration, node, context); 95 } 96 case kTfLiteBuiltinFullyConnected: { 97 return delegates::coreml::IsFullyConnectedOpSupported(registration, node, context); 98 } 99 case kTfLiteBuiltinHardSwish: { 100 return true; 101 } 102 case kTfLiteBuiltinLogistic: { 103 return true; 104 } 105 case kTfLiteBuiltinMaxPool2d: { 106 const auto* params = reinterpret_cast<const TfLitePoolParams*>(node->builtin_data); 107 return params != nullptr && params->activation == kTfLiteActNone; 108 } 109 case kTfLiteBuiltinMirrorPad: { 110 return delegates::coreml::IsMirrorPadOpSupported(registration, node, context); 111 } 112 case kTfLiteBuiltinMean: { 113 return delegates::coreml::IsMeanOpSupported(registration, node, context); 114 } 115 case kTfLiteBuiltinMul: { 116 return node->builtin_data != nullptr && 117 delegates::coreml::IsBinaryOpSupported(registration, node, context); 118 } 119 case kTfLiteBuiltinPad: 120 case kTfLiteBuiltinPadv2: { 121 return delegates::coreml::IsPadOpSupported(registration, node, context); 122 } 123 case kTfLiteBuiltinRelu: { 124 return true; 125 } 126 case kTfLiteBuiltinReluN1To1: { 127 return true; 128 } 129 case kTfLiteBuiltinRelu6: { 130 return true; 131 } 132 case kTfLiteBuiltinReshape: { 133 return delegates::coreml::IsReshapeOpSupported(registration, node, context, 134 options->coreml_version); 135 } 136 case kTfLiteBuiltinResizeBilinear: { 137 return delegates::coreml::IsResizeBilinearOpSupported(registration, node, context); 138 } 139 case kTfLiteBuiltinSoftmax: { 140 // Only supports when beta is 1.0 for now. 141 const auto* softmax_params = reinterpret_cast<const TfLiteSoftmaxParams*>(node->builtin_data); 142 return softmax_params != nullptr && softmax_params->beta == 1.0; 143 } 144 case kTfLiteBuiltinTanh: { 145 return true; 146 } 147 case kTfLiteBuiltinTransposeConv: { 148 return delegates::coreml::IsTransposeConvolutionOpSupported(registration, node, context); 149 } 150 default: 151 return false; 152 } 153 return false; 154} 155 156class CoreMlDelegate : public TfLiteDelegate { 157 public: 158 explicit CoreMlDelegate(const TfLiteCoreMlDelegateOptions* params) 159 : params_(params != nullptr ? *params : TfLiteCoreMlDelegateOptions()) { 160 { 161 if (@available(iOS 13.0, *)) { 162 if (params_.coreml_version != 2 && params_.coreml_version != 3) { 163 NSLog(@"coreml_version must be 2 or 3. Setting to 3."); 164 params_.coreml_version = 3; 165 } 166 } else if (@available(iOS 12.0, *)) { 167 if (params_.coreml_version != 2) { 168 NSLog(@"coreml_version must be 2 - using Core ML version 2."); 169 params_.coreml_version = 2; 170 } 171 } 172 if (params_.max_delegated_partitions <= 0) { 173 params_.max_delegated_partitions = std::numeric_limits<int>::max(); 174 } 175 if (params_.min_nodes_per_partition <= 0) { 176 params_.min_nodes_per_partition = kMinNodesPerCoreMlDelegate; 177 } 178 } 179 } 180 181 TfLiteCoreMlDelegateOptions* params() { return ¶ms_; } 182 183 bool VerifyDelegate() { return true; } 184 185 private: 186 TfLiteCoreMlDelegateOptions params_; 187}; 188 189TfLiteRegistration GetCoreMlKernelRegistration() { 190 // This is the registration for the Delegate Node that gets added to 191 // the TFLite graph instead of the subGraph it replaces it. 192 // It is treated as an OP node. But in our case 193 // Init will initialize the delegate 194 // Invoke will run the delegate graph. 195 // Prepare for prearing the delegate. 196 // Free for any cleaning needed by the delegate. 197 TfLiteRegistration kernel_registration; 198 kernel_registration.builtin_code = kTfLiteBuiltinDelegate; 199 kernel_registration.custom_name = "TfLiteCoreMlDelegate"; 200 kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void { 201 delete reinterpret_cast<CoreMlDelegateKernel*>(buffer); 202 }; 203 kernel_registration.init = [](TfLiteContext* context, const char* buffer, 204 size_t length) -> void* { 205 const auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer); 206 const auto* coreml_options = (reinterpret_cast<CoreMlDelegate*>(params->delegate))->params(); 207 CoreMlDelegateKernel* coreml_kernel = new CoreMlDelegateKernel(coreml_options->coreml_version); 208 if (coreml_kernel->Init(context, params) != kTfLiteOk) { 209 delete coreml_kernel; 210 return nullptr; 211 } 212 return coreml_kernel; 213 }; 214 kernel_registration.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { 215 CoreMlDelegateKernel* kernel = reinterpret_cast<CoreMlDelegateKernel*>(node->user_data); 216 if (!kernel) { 217 TF_LITE_KERNEL_LOG(context, "CoreMl Kernel was not initialized"); 218 return kTfLiteError; 219 } 220 return kernel->Invoke(context, node); 221 }; 222 kernel_registration.prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { 223 CoreMlDelegateKernel* kernel = reinterpret_cast<CoreMlDelegateKernel*>(node->user_data); 224 if (kernel == nullptr) { 225 TF_LITE_KERNEL_LOG(context, "CoreMl Kernel was not initialized"); 226 return kTfLiteError; 227 } 228 return kernel->Prepare(context, node); 229 }; 230 231 return kernel_registration; 232} 233 234TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { 235 const auto* params = reinterpret_cast<TfLiteCoreMlDelegateOptions*>(delegate->data_); 236 237 delegates::IsNodeSupportedFn node_supported_fn = [=](TfLiteContext* context, TfLiteNode* node, 238 TfLiteRegistration* registration, 239 std::string* unsupported_details) -> bool { 240 return IsNodeSupportedByDelegate(registration, node, context, params); 241 }; 242 243 delegates::FP16GraphPartitionHelper partition_helper(context, node_supported_fn); 244 TF_LITE_ENSURE_STATUS(partition_helper.Partition(nullptr)); 245 246 std::vector<int> delegated_nodes = partition_helper.GetNodesOfFirstNLargestPartitions( 247 params->max_delegated_partitions, params->min_nodes_per_partition); 248 TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, 249 "CoreML delegate: %d nodes delegated out of %d nodes, " 250 "with %d partitions.\n", 251 delegated_nodes.size(), partition_helper.num_total_nodes(), 252 partition_helper.num_partitions()); 253 return context->ReplaceNodeSubsetsWithDelegateKernels( 254 context, GetCoreMlKernelRegistration(), BuildTfLiteIntArray(delegated_nodes).get(), delegate); 255} 256 257TfLiteDelegate* CreateCoreMlDelegate(const TfLiteCoreMlDelegateOptions* options) { 258 TfLiteDelegate* delegate = new CoreMlDelegate(options); 259 if (!static_cast<CoreMlDelegate*>(delegate)->VerifyDelegate()) { 260 delete delegate; 261 return nullptr; 262 } 263 264 delegate->data_ = static_cast<tflite::CoreMlDelegate*>(delegate)->params(); 265 delegate->flags = kTfLiteDelegateFlagsNone; 266 delegate->Prepare = &DelegatePrepare; 267 delegate->CopyFromBufferHandle = nullptr; 268 delegate->CopyToBufferHandle = nullptr; 269 delegate->FreeBufferHandle = nullptr; 270 271 return delegate; 272} 273} // namespace 274} // namespace tflite 275 276namespace { 277// utsname.machine has device identifier. For example, identifier for iPhone Xs is "iPhone11,2". 278// Since Neural Engine is only available for use on A12 and later, major device version in the 279// identifier is checked for these models: 280// A12: iPhone XS (11,2), iPad Mini - 5th Gen (11,1) 281// A12X: iPad Pro - 3rd Gen (8,1) 282// For more information, see https://www.theiphonewiki.com/wiki/Models 283bool IsNeuralEngineAvailable() { 284 struct utsname system_info; 285 uname(&system_info); 286 287 if (strncmp("iPad", system_info.machine, 4) == 0) { 288 const int major_version = atoi(system_info.machine + 4); 289 return major_version >= 8; // There are no device between iPad 8 and 11. 290 } else if (strncmp("iPhone", system_info.machine, 6) == 0) { 291 const int major_version = atoi(system_info.machine + 6); 292 return major_version >= 11; 293 } 294 return false; 295} 296 297} // namespace 298 299TfLiteDelegate* TfLiteCoreMlDelegateCreate(const TfLiteCoreMlDelegateOptions* options) { 300 if (@available(iOS 12.0, *)) { 301 if (options->enabled_devices == TfLiteCoreMlDelegateDevicesWithNeuralEngine && 302 !IsNeuralEngineAvailable()) { 303 NSLog(@"This device does not have Neural Engine, so Core ML delegate will not be enabled. " 304 "If you want to run Core ML delegate anyway, set enabled_devices option to " 305 "TfLiteCoreMlDelegateAllDevices (or enabledDevices to .allDevices in Swift)."); 306 return nullptr; 307 } 308 return tflite::CreateCoreMlDelegate(options); 309 } else { 310 NSLog(@"Core ML delegate is not supported in this iOS version. " 311 "Minimum required iOS version is 12.0."); 312 return nullptr; 313 } 314} 315 316void TfLiteCoreMlDelegateDelete(TfLiteDelegate* delegate) { delete delegate; } 317