// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "tensorflow/lite/objc/apis/TFLInterpreter.h"

#include <vector>

#import "TFLCommonUtil.h"
#import "TFLErrorUtil.h"
#import "TFLQuantizationParameters+Internal.h"
#import "TFLSignatureRunner+Internal.h"
#import "TFLTensor+Internal.h"
#import "tensorflow/lite/objc/apis/TFLDelegate.h"
#import "tensorflow/lite/objc/apis/TFLInterpreterOptions.h"
#import "tensorflow/lite/objc/apis/TFLTensor.h"

#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_experimental.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

NS_ASSUME_NONNULL_BEGIN

/** Version string reported by the TensorFlow Lite C API, or an empty string if it reports none. */
FOUNDATION_EXPORT NSString *const TFLVersion =
    TfLiteVersion() == nullptr ? @"" : [NSString stringWithUTF8String:TfLiteVersion()];

/**
 * Error reporter for TFLInterpreter.
 *
 * Formats the message reported by the TensorFlow Lite runtime and writes it to the console with
 * `NSLog`.
 *
 * @param user_data User data. Not used.
 * @param format Error message which may contain argument formatting specifiers.
 * @param args Values of the arguments in the error message.
 */
static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_list args) {
  if (format == nullptr) {
    return;
  }
  // `@(format)` is nil when `format` is not valid UTF-8, and `-initWithFormat:arguments:` raises on
  // a nil format. Raising an Objective-C exception from inside a C callback invoked by the runtime
  // is unsafe, so fall back to logging the raw format string instead.
  NSString *formatString = @(format);
  if (formatString == nil) {
    NSLog(@"%s", format);
    return;
  }
  NSLog(@"%@", [[NSString alloc] initWithFormat:formatString arguments:args]);
}

@interface TFLInterpreter ()

/** TfLiteInterpreter backed by C API. */
@property(nonatomic) TfLiteInterpreter *interpreter;

/** TfLiteDelegate backed by C API. */
@property(nonatomic, nullable) TfLiteDelegate *xnnPackDelegate;

/**
 * Delegates passed at initialization.
 *
 * The C interpreter only receives the raw `cDelegate` pointers of these objects. The wrappers are
 * retained here so that they stay alive at least as long as the interpreter. This assumes each
 * wrapper owns its C delegate and frees it on deallocation — TODO confirm in TFLDelegate subclasses.
 */
@property(nonatomic, copy, nullable) NSArray<TFLDelegate *> *retainedDelegates;

@end

@implementation TFLInterpreter

@synthesize signatureKeys = _signatureKeys;

#pragma mark - NSObject

- (void)dealloc {
  // The interpreter is deleted before the XNNPACK delegate it may reference. Strong ivars such as
  // `_retainedDelegates` are released by ARC only after this method returns.
  TfLiteInterpreterDelete(_interpreter);
  TfLiteXNNPackDelegateDelete(_xnnPackDelegate);
}

#pragma mark - Public

/** Creates an interpreter for the model at `modelPath` with default options and no delegates. */
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
  return [self initWithModelPath:modelPath
                         options:[[TFLInterpreterOptions alloc] init]
                       delegates:@[]
                           error:error];
}

/** Creates an interpreter for the model at `modelPath` with the given options and no delegates. */
- (nullable instancetype)initWithModelPath:(NSString *)modelPath
                                   options:(TFLInterpreterOptions *)options
                                     error:(NSError **)error {
  return [self initWithModelPath:modelPath options:options delegates:@[] error:error];
}

/**
 * Loads the model at `modelPath` and builds a C interpreter from it.
 *
 * If `options.useXNNPACK` is set, an XNNPACK delegate is created (sharing the thread count when
 * `options.numberOfThreads` is positive) and added before the caller's `delegates`. Delegates
 * whose `cDelegate` is null are skipped.
 *
 * @return The interpreter, or nil with `error` populated if the model path is not representable as
 *     a C string, the model or interpreter cannot be created, or the interpreter reports zero input
 *     or output tensors.
 */
- (nullable instancetype)initWithModelPath:(NSString *)modelPath
                                   options:(TFLInterpreterOptions *)options
                                 delegates:(NSArray<TFLDelegate *> *)delegates
                                     error:(NSError **)error {
  self = [super init];

  if (self != nil) {
    // Keep the delegate wrappers alive: only their raw C pointers are handed to the interpreter.
    _retainedDelegates = [delegates copy];

    TfLiteModel *model = nullptr;
    TfLiteInterpreterOptions *cOptions = nullptr;

    @try {
      const char *modelPathCString = modelPath.UTF8String;
      NSString *pathErrorString =
          [NSString stringWithFormat:@"Cannot load model from path (%@).", modelPath];
      if (modelPathCString == nullptr) {
        [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToLoadModel
                                       description:pathErrorString
                                             error:error];
        return nil;
      }

      model = TfLiteModelCreateFromFile(modelPathCString);
      if (model == nullptr) {
        [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToLoadModel
                                       description:pathErrorString
                                             error:error];
        return nil;
      }

      cOptions = TfLiteInterpreterOptionsCreate();
      if (cOptions == nullptr) {
        [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
                                       description:@"Failed to create the interpreter."
                                             error:error];
        return nil;
      }

      if (options.numberOfThreads > 0) {
        TfLiteInterpreterOptionsSetNumThreads(cOptions, (int32_t)options.numberOfThreads);
      }
      TfLiteInterpreterOptionsSetErrorReporter(cOptions, TFLInterpreterErrorReporter, nullptr);

      if (options.useXNNPACK) {
        TfLiteXNNPackDelegateOptions xnnPackOptions = TfLiteXNNPackDelegateOptionsDefault();
        if (options.numberOfThreads > 0) {
          xnnPackOptions.num_threads = (int32_t)options.numberOfThreads;
        }

        // Owned by this object and deleted in -dealloc, after the interpreter.
        _xnnPackDelegate = TfLiteXNNPackDelegateCreate(&xnnPackOptions);
        TfLiteInterpreterOptionsAddDelegate(cOptions, _xnnPackDelegate);
      }

      for (TFLDelegate *delegate in delegates) {
        if (delegate.cDelegate != nullptr) {
          TfLiteInterpreterOptionsAddDelegate(
              cOptions, reinterpret_cast<TfLiteDelegate *>(delegate.cDelegate));
        }
      }

      TfLiteInterpreter *interpreter = TfLiteInterpreterCreate(model, cOptions);
      if (interpreter == nullptr) {
        [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
                                       description:@"Failed to create the interpreter."
                                             error:error];
        return nil;
      }
      _interpreter = interpreter;

      _inputTensorCount = (NSUInteger)TfLiteInterpreterGetInputTensorCount(_interpreter);
      _outputTensorCount = (NSUInteger)TfLiteInterpreterGetOutputTensorCount(_interpreter);
      if (_inputTensorCount == 0 || _outputTensorCount == 0) {
        [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
                                       description:@"Failed to create the interpreter."
                                             error:error];
        return nil;
      }
    } @finally {
      // The options and model are only needed to build the interpreter, so they are released on
      // every exit path from the block above.
      TfLiteInterpreterOptionsDelete(cOptions);
      TfLiteModelDelete(model);
    }
  }

  return self;
}

/**
 * Signature keys of the model, read from the C interpreter on first access and cached.
 * A key that the C API returns as null, or that is not valid UTF-8, is reported as an empty string.
 */
- (NSArray<NSString *> *)signatureKeys {
  if (_signatureKeys) return _signatureKeys;
  NSUInteger signatureCount = TfLiteInterpreterGetSignatureCount(self.interpreter);
  NSMutableArray<NSString *> *mutableKeyArray =
      [[NSMutableArray alloc] initWithCapacity:signatureCount];
  for (NSUInteger i = 0; i < signatureCount; i++) {
    const char *signatureNameCString =
        TfLiteInterpreterGetSignatureKey(self.interpreter, (int32_t)i);
    NSString *signatureName = @"";
    if (signatureNameCString != nullptr) {
      signatureName = [NSString stringWithUTF8String:signatureNameCString] ?: @"";
    }
    [mutableKeyArray addObject:signatureName];
  }
  _signatureKeys = [mutableKeyArray copy];
  return _signatureKeys;
}

/** Runs the model. Returns NO and populates `error` if the C API invocation fails. */
- (BOOL)invokeWithError:(NSError **)error {
  if (TfLiteInterpreterInvoke(self.interpreter) != kTfLiteOk) {
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToInvoke
                                   description:@"Failed to invoke the interpreter."
                                         error:error];
    return NO;
  }

  return YES;
}

/** Returns the input tensor at `index`, or nil with `error` populated if `index` is out of range. */
- (nullable TFLTensor *)inputTensorAtIndex:(NSUInteger)index error:(NSError **)error {
  if (![self isValidTensorIndex:index belowLimit:self.inputTensorCount error:error]) {
    return nil;
  }

  return [self tensorOfType:TFLTensorTypeInput atIndex:index error:error];
}

/** Returns the output tensor at `index`, or nil with `error` populated if `index` is out of range. */
- (nullable TFLTensor *)outputTensorAtIndex:(NSUInteger)index error:(NSError **)error {
  if (![self isValidTensorIndex:index belowLimit:self.outputTensorCount error:error]) {
    return nil;
  }

  return [self tensorOfType:TFLTensorTypeOutput atIndex:index error:error];
}

/**
 * Resizes the input tensor at `index` to `shape`.
 *
 * @return NO with `error` populated if the index is out of range, `shape` is empty, any dimension
 *     is not positive (as read through `-intValue`), or the C API rejects the resize.
 */
- (BOOL)resizeInputTensorAtIndex:(NSUInteger)index
                         toShape:(NSArray<NSNumber *> *)shape
                           error:(NSError **)error {
  if (![self isValidTensorIndex:index belowLimit:self.inputTensorCount error:error]) {
    return NO;
  }

  if (shape.count == 0) {
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape
                                   description:@"Invalid shape. Must not be empty."
                                         error:error];
    return NO;
  }

  std::vector<int> cDimensions(shape.count);
  for (NSUInteger dimIndex = 0; dimIndex < shape.count; ++dimIndex) {
    int dimension = shape[dimIndex].intValue;
    if (dimension <= 0) {
      NSString *errorDescription = @"Invalid shape. Dimensions must be positive integers.";
      [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape
                                     description:errorDescription
                                           error:error];
      return NO;
    }
    cDimensions[dimIndex] = dimension;
  }

  if (TfLiteInterpreterResizeInputTensor(self.interpreter, (int32_t)index, cDimensions.data(),
                                         (int32_t)shape.count) != kTfLiteOk) {
    NSString *errorDescription = [NSString
        stringWithFormat:@"Failed to resize input tensor at index (%lu).", (unsigned long)index];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToResizeInputTensor
                                   description:errorDescription
                                         error:error];
    return NO;
  }

  return YES;
}

/** Allocates memory for all tensors. Returns NO and populates `error` if the C API call fails. */
- (BOOL)allocateTensorsWithError:(NSError **)error {
  if (TfLiteInterpreterAllocateTensors(self.interpreter) != kTfLiteOk) {
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToAllocateTensors
                                   description:@"Failed to allocate memory for tensors."
                                         error:error];
    return NO;
  }
  return YES;
}

/**
 * Creates a signature runner for `key`.
 *
 * @return nil with an error in `TFLSignatureRunnerErrorDomain` if `key` is not one of
 *     `signatureKeys`; otherwise whatever the `TFLSignatureRunner` initializer returns.
 */
- (nullable TFLSignatureRunner *)signatureRunnerWithKey:(NSString *)key error:(NSError **)error {
  if (![self.signatureKeys containsObject:key]) {
    NSString *errorDescription = [NSString
        stringWithFormat:@"Failed to create a signature runner. Signature with key (%@) not found.",
                         key];
    [TFLErrorUtil setError:error
                withDomain:TFLSignatureRunnerErrorDomain
                      code:TFLSignatureRunnerErrorCodeFailedToCreateSignatureRunner
               description:errorDescription];
    return nil;
  }
  return [[TFLSignatureRunner alloc] initWithInterpreter:self signatureKey:key error:error];
}

#pragma mark - TFLTensorDataAccessor

/**
 * Copies `data` into the buffer of `inputTensor`.
 *
 * @return NO with `error` populated if `inputTensor` is an output tensor, the C tensor cannot be
 *     found, `data.length` differs from the tensor's byte size, or the C copy fails.
 */
- (BOOL)copyData:(NSData *)data toInputTensor:(TFLTensor *)inputTensor error:(NSError **)error {
  if (inputTensor.type == TFLTensorTypeOutput) {
    [TFLErrorUtil
        saveInterpreterErrorWithCode:TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed
                         description:@"Cannot copy data into an output tensor."
                               error:error];
    return NO;
  }
  NSUInteger tensorIndex = inputTensor.index;
  const TfLiteTensor *cTensor = [self cTensorOfType:TFLTensorTypeInput
                                            atIndex:tensorIndex
                                              error:error];
  if (cTensor == nullptr) {
    return NO;
  }

  NSUInteger byteSize = (NSUInteger)TfLiteTensorByteSize(cTensor);
  if (data.length != byteSize) {
    NSString *errorDescription = [NSString
        stringWithFormat:@"Input tensor at index (%lu) expects data size (%lu), but got (%lu).",
                         (unsigned long)tensorIndex, (unsigned long)byteSize,
                         (unsigned long)data.length];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidInputByteSize
                                   description:errorDescription
                                         error:error];
    return NO;
  }

  // The C accessor returns a const pointer, but copying into an input tensor requires a mutable one.
  if (TfLiteTensorCopyFromBuffer(const_cast<TfLiteTensor *>(cTensor), data.bytes, data.length) !=
      kTfLiteOk) {
    NSString *errorDescription =
        [NSString stringWithFormat:@"Failed to copy data into input tensor at index (%lu).",
                                   (unsigned long)tensorIndex];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCopyDataToInputTensor
                                   description:errorDescription
                                         error:error];
    return NO;
  }

  return YES;
}

/**
 * Returns a copy of the bytes currently held by `tensor`.
 *
 * @return nil with `error` populated if the C tensor cannot be found or its data pointer is null or
 *     its byte size is zero.
 */
- (nullable NSData *)dataFromTensor:(TFLTensor *)tensor error:(NSError **)error {
  const TfLiteTensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error];
  if (cTensor == nullptr) {
    return nil;
  }

  void *bytes = TfLiteTensorData(cTensor);
  NSUInteger byteSize = (NSUInteger)TfLiteTensorByteSize(cTensor);
  if (bytes == nullptr || byteSize == 0) {
    NSString *tensorType = [TFLTensor stringForTensorType:tensor.type];
    NSString *errorDescription =
        [NSString stringWithFormat:@"Failed to get data from %@ tensor at index (%lu).", tensorType,
                                   (unsigned long)tensor.index];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetDataFromTensor
                                   description:errorDescription
                                         error:error];
    return nil;
  }

  return [NSData dataWithBytes:bytes length:byteSize];
}

/**
 * Returns the dimensions of `tensor`, outermost first.
 *
 * @return nil with `error` populated if the C tensor cannot be found, its rank is not positive, or
 *     any dimension is not positive.
 */
- (nullable NSArray<NSNumber *> *)shapeOfTensor:(TFLTensor *)tensor error:(NSError **)error {
  const TfLiteTensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error];
  if (cTensor == nullptr) {
    return nil;
  }

  NSString *tensorType = [TFLTensor stringForTensorType:tensor.type];
  int32_t rank = TfLiteTensorNumDims(cTensor);
  if (rank <= 0) {
    NSString *errorDescription =
        [NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid rank (%d).", tensorType,
                                   (unsigned long)tensor.index, rank];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
                                   description:errorDescription
                                         error:error];
    return nil;
  }

  NSMutableArray<NSNumber *> *shape = [NSMutableArray arrayWithCapacity:(NSUInteger)rank];
  for (int32_t dimIndex = 0; dimIndex < rank; dimIndex++) {
    int32_t dimension = TfLiteTensorDim(cTensor, dimIndex);
    if (dimension <= 0) {
      NSString *errorDescription =
          [NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid %d-th dimension (%d).",
                                     tensorType, (unsigned long)tensor.index, dimIndex, dimension];
      [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
                                     description:errorDescription
                                           error:error];
      return nil;
    }
    [shape addObject:@((NSUInteger)dimension)];
  }

  return shape;
}

#pragma mark - Private

/**
 * Looks up the C tensor of the given type at `index`.
 *
 * @return The tensor, or nullptr with `error` populated if the C API returns none.
 */
- (const TfLiteTensor *)cTensorOfType:(TFLTensorType)type
                              atIndex:(NSUInteger)index
                                error:(NSError **)error {
  const TfLiteTensor *tensor = nullptr;

  switch (type) {
    case TFLTensorTypeInput:
      tensor = TfLiteInterpreterGetInputTensor(self.interpreter, (int32_t)index);
      break;
    case TFLTensorTypeOutput:
      tensor = TfLiteInterpreterGetOutputTensor(self.interpreter, (int32_t)index);
      break;
  }

  if (tensor == nullptr) {
    NSString *tensorType = [TFLTensor stringForTensorType:type];
    NSString *errorDescription =
        [NSString stringWithFormat:@"Failed to get %@ tensor at index (%lu).", tensorType,
                                   (unsigned long)index];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetTensor
                                   description:errorDescription
                                         error:error];
  }

  return tensor;
}

/**
 * Wraps the C tensor of the given type at `index` in a `TFLTensor`, copying its name, data type and
 * quantization parameters.
 *
 * @return nil with `error` populated if the C tensor cannot be found or has no readable name.
 */
- (nullable TFLTensor *)tensorOfType:(TFLTensorType)type
                             atIndex:(NSUInteger)index
                               error:(NSError **)error {
  const TfLiteTensor *tensor = [self cTensorOfType:type atIndex:index error:error];

  if (tensor == nullptr) {
    return nil;
  }

  NSString *name = TFLTensorNameFromCTensor(tensor);
  if (!name) {
    NSString *tensorType = [TFLTensor stringForTensorType:type];
    NSString *errorDescription =
        [NSString stringWithFormat:@"Failed to get name of %@ tensor at index (%lu).", tensorType,
                                   (unsigned long)index];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
                                   description:errorDescription
                                         error:error];
    return nil;
  }
  TFLTensorDataType dataType = TFLTensorDataTypeFromCTensor(tensor);
  TFLQuantizationParameters *quantizationParams = TFLQuantizationParamsFromCTensor(tensor);

  return [[TFLTensor alloc] initWithInterpreter:self
                                           type:type
                                          index:index
                                           name:name
                                       dataType:dataType
                         quantizationParameters:quantizationParams];
}

/**
 * Returns YES if `index` is below `totalTensorCount`; otherwise returns NO and populates `error`.
 */
- (BOOL)isValidTensorIndex:(NSUInteger)index
                belowLimit:(NSUInteger)totalTensorCount
                     error:(NSError **)error {
  if (index >= totalTensorCount) {
    // Explicit casts keep the `%lu` specifiers correct on targets where NSUInteger is 32-bit.
    NSString *errorDescription =
        [NSString stringWithFormat:@"Invalid tensor index (%lu) exceeds max (%lu).",
                                   (unsigned long)index, (unsigned long)(totalTensorCount - 1)];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensorIndex
                                   description:errorDescription
                                         error:error];
    return NO;
  }

  return YES;
}

@end

NS_ASSUME_NONNULL_END