• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at:
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#import "tensorflow/lite/objc/apis/TFLInterpreter.h"
16
17#include <vector>
18
19#import "TFLCommonUtil.h"
20#import "TFLErrorUtil.h"
21#import "TFLQuantizationParameters+Internal.h"
22#import "TFLSignatureRunner+Internal.h"
23#import "TFLTensor+Internal.h"
24#import "tensorflow/lite/objc/apis/TFLDelegate.h"
25#import "tensorflow/lite/objc/apis/TFLInterpreterOptions.h"
26#import "tensorflow/lite/objc/apis/TFLTensor.h"
27
28#include "tensorflow/lite/c/c_api.h"
29#include "tensorflow/lite/c/c_api_experimental.h"
30#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
31
32NS_ASSUME_NONNULL_BEGIN
33
/**
 * Version string reported by `TfLiteVersion()`, or an empty string when that call returns null.
 */
FOUNDATION_EXPORT NSString *const TFLVersion =
    TfLiteVersion() != nullptr ? [NSString stringWithUTF8String:TfLiteVersion()] : @"";
36
/**
 * Error reporter for TFLInterpreter. Expands the message and writes it with NSLog.
 *
 * A null or undecodable `format` is reported with a generic message instead of being boxed,
 * because boxing a NULL C string and passing a nil format string both raise an exception, and
 * this function is installed as a plain C callback.
 *
 * @param user_data User data. Not used.
 * @param format Error message which may contain argument formatting specifiers.
 * @param args Values of the arguments in the error message.
 */
static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_list args) {
  NSString *formatString = nil;
  if (format != nullptr) {
    // Returns nil when the bytes are not valid UTF-8.
    formatString = [NSString stringWithUTF8String:format];
  }

  if (formatString == nil) {
    NSLog(@"%@", @"TensorFlow Lite reported an error with an unreadable message.");
    return;
  }

  NSLog(@"%@", [[NSString alloc] initWithFormat:formatString arguments:args]);
}
47
@interface TFLInterpreter ()

/** Handle to the C API interpreter; deleted in `-dealloc`. */
@property(nonatomic, assign) TfLiteInterpreter *interpreter;

/** C API XNNPACK delegate created by the initializer when XNNPACK is requested; deleted in `-dealloc`. */
@property(nonatomic, assign, nullable) TfLiteDelegate *xnnPackDelegate;

@end
57
58@implementation TFLInterpreter
59
60@synthesize signatureKeys = _signatureKeys;
61
62#pragma mark - NSObject
63
/** Releases the C API objects owned by this instance. */
- (void)dealloc {
  // Delete the interpreter before the XNNPACK delegate that was attached to it at creation.
  TfLiteInterpreterDelete(_interpreter);
  _interpreter = nullptr;

  TfLiteXNNPackDelegateDelete(_xnnPackDelegate);
  _xnnPackDelegate = nullptr;
}
68
69#pragma mark - Public
70
/**
 * Creates an interpreter for the model at `modelPath` using freshly created default options and
 * no delegates.
 */
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
  TFLInterpreterOptions *defaultOptions = [[TFLInterpreterOptions alloc] init];
  NSArray<TFLDelegate *> *noDelegates = @[];
  return [self initWithModelPath:modelPath
                         options:defaultOptions
                       delegates:noDelegates
                           error:error];
}
77
/** Creates an interpreter for the model at `modelPath` with the given options and no delegates. */
- (nullable instancetype)initWithModelPath:(NSString *)modelPath
                                   options:(TFLInterpreterOptions *)options
                                     error:(NSError **)error {
  NSArray<TFLDelegate *> *noDelegates = @[];
  return [self initWithModelPath:modelPath
                         options:options
                       delegates:noDelegates
                           error:error];
}
83
/**
 * Creates an interpreter for the model file at `modelPath`. The other initializers in this file
 * forward to this one.
 *
 * Steps: load the model, build the C API options (thread count, error reporter, optional XNNPACK
 * delegate, caller-supplied delegates), create the C interpreter, and cache the input and output
 * tensor counts. The temporary C model and options objects are always deleted before returning.
 *
 * @param modelPath Path of the model file to load.
 * @param options Options; `numberOfThreads` is applied only when positive, and `useXNNPACK`
 *     controls creation of the XNNPACK delegate.
 * @param delegates Delegates to attach; entries whose `cDelegate` is null are skipped.
 * @param error Populated on failure.
 * @return The initialized interpreter, or `nil` when the model cannot be loaded, the interpreter
 *     cannot be created, or the model reports no input or no output tensors.
 */
- (nullable instancetype)initWithModelPath:(NSString *)modelPath
                                   options:(TFLInterpreterOptions *)options
                                 delegates:(NSArray<TFLDelegate *> *)delegates
                                     error:(NSError **)error {
  self = [super init];
  if (self == nil) {
    return nil;
  }

  TfLiteModel *model = nullptr;
  TfLiteInterpreterOptions *cOptions = nullptr;

  @try {
    // A path without a UTF-8 representation and a file that fails to load are reported alike.
    const char *modelPathCString = modelPath.UTF8String;
    if (modelPathCString != nullptr) {
      model = TfLiteModelCreateFromFile(modelPathCString);
    }
    if (model == nullptr) {
      NSString *pathErrorString =
          [NSString stringWithFormat:@"Cannot load model from path (%@).", modelPath];
      [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToLoadModel
                                     description:pathErrorString
                                           error:error];
      return nil;
    }

    NSString *const createErrorString = @"Failed to create the interpreter.";

    cOptions = TfLiteInterpreterOptionsCreate();
    if (cOptions == nullptr) {
      [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
                                     description:createErrorString
                                           error:error];
      return nil;
    }

    const BOOL hasExplicitThreadCount = options.numberOfThreads > 0;
    if (hasExplicitThreadCount) {
      TfLiteInterpreterOptionsSetNumThreads(cOptions, (int32_t)options.numberOfThreads);
    }
    TfLiteInterpreterOptionsSetErrorReporter(cOptions, TFLInterpreterErrorReporter, nullptr);

    if (options.useXNNPACK) {
      TfLiteXNNPackDelegateOptions xnnPackOptions = TfLiteXNNPackDelegateOptionsDefault();
      if (hasExplicitThreadCount) {
        xnnPackOptions.num_threads = (int32_t)options.numberOfThreads;
      }

      _xnnPackDelegate = TfLiteXNNPackDelegateCreate(&xnnPackOptions);
      if (_xnnPackDelegate != nullptr) {
        TfLiteInterpreterOptionsAddDelegate(cOptions, _xnnPackDelegate);
      } else {
        // Never hand a null delegate to the C API; continue without XNNPACK instead.
        NSLog(@"%@", @"Failed to create the XNNPACK delegate. Continuing without it.");
      }
    }

    for (TFLDelegate *delegate in delegates) {
      if (delegate.cDelegate != nullptr) {
        TfLiteInterpreterOptionsAddDelegate(
            cOptions, reinterpret_cast<TfLiteDelegate *>(delegate.cDelegate));
      }
    }

    TfLiteInterpreter *interpreter = TfLiteInterpreterCreate(model, cOptions);
    if (interpreter == nullptr) {
      [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
                                     description:createErrorString
                                           error:error];
      return nil;
    }
    // From here on -dealloc owns the interpreter, including on the failure path below.
    _interpreter = interpreter;

    _inputTensorCount = (NSUInteger)TfLiteInterpreterGetInputTensorCount(_interpreter);
    _outputTensorCount = (NSUInteger)TfLiteInterpreterGetOutputTensorCount(_interpreter);
    // Both counts are unsigned, so the only invalid value that can be detected here is zero.
    if (_inputTensorCount == 0 || _outputTensorCount == 0) {
      [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
                                     description:createErrorString
                                           error:error];
      return nil;
    }
  } @finally {
    // Runs on every exit path above, including the early `return nil` statements.
    TfLiteInterpreterOptionsDelete(cOptions);
    TfLiteModelDelete(model);
  }

  return self;
}
168
/**
 * Lazily builds and caches the list of signature keys reported by the C interpreter. A key that
 * is null or cannot be converted to an NSString is stored as an empty string.
 */
- (NSArray<NSString *> *)signatureKeys {
  if (_signatureKeys == nil) {
    NSUInteger keyCount = TfLiteInterpreterGetSignatureCount(self.interpreter);
    NSMutableArray<NSString *> *keys = [[NSMutableArray alloc] initWithCapacity:keyCount];

    for (NSUInteger keyIndex = 0; keyIndex < keyCount; ++keyIndex) {
      const char *cKey = TfLiteInterpreterGetSignatureKey(self.interpreter, (int32_t)keyIndex);
      NSString *key = (cKey == nullptr) ? nil : [NSString stringWithUTF8String:cKey];
      [keys addObject:(key ?: @"")];
    }

    _signatureKeys = [keys copy];
  }
  return _signatureKeys;
}
186
/** Invokes the C interpreter; returns `NO` and populates `error` when the status is not OK. */
- (BOOL)invokeWithError:(NSError **)error {
  TfLiteStatus status = TfLiteInterpreterInvoke(self.interpreter);
  if (status == kTfLiteOk) {
    return YES;
  }

  [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToInvoke
                                 description:@"Failed to invoke the interpreter."
                                       error:error];
  return NO;
}
197
/** Returns the input tensor at `index`, or `nil` when the index is out of range or lookup fails. */
- (nullable TFLTensor *)inputTensorAtIndex:(NSUInteger)index error:(NSError **)error {
  BOOL indexIsValid = [self isValidTensorIndex:index
                                    belowLimit:self.inputTensorCount
                                         error:error];
  return indexIsValid ? [self tensorOfType:TFLTensorTypeInput atIndex:index error:error] : nil;
}
205
/** Returns the output tensor at `index`, or `nil` when the index is out of range or lookup fails. */
- (nullable TFLTensor *)outputTensorAtIndex:(NSUInteger)index error:(NSError **)error {
  BOOL indexIsValid = [self isValidTensorIndex:index
                                    belowLimit:self.outputTensorCount
                                         error:error];
  return indexIsValid ? [self tensorOfType:TFLTensorTypeOutput atIndex:index error:error] : nil;
}
213
/**
 * Resizes the input tensor at `index` to `shape`.
 *
 * Fails when the index is out of range, when `shape` is empty, when any dimension is not a
 * positive integer, or when the C API rejects the resize.
 */
- (BOOL)resizeInputTensorAtIndex:(NSUInteger)index
                         toShape:(NSArray<NSNumber *> *)shape
                           error:(NSError **)error {
  if (![self isValidTensorIndex:index belowLimit:self.inputTensorCount error:error]) {
    return NO;
  }

  const NSUInteger rank = shape.count;
  if (rank == 0) {
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape
                                   description:@"Invalid shape. Must not be empty."
                                         error:error];
    return NO;
  }

  // Convert the boxed dimensions into the plain int buffer expected by the C API.
  std::vector<int> cDimensions;
  cDimensions.reserve(rank);
  for (NSNumber *boxedDimension in shape) {
    int dimension = boxedDimension.intValue;
    if (dimension <= 0) {
      [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape
                                     description:@"Invalid shape. Dimensions must be positive integers."
                                           error:error];
      return NO;
    }
    cDimensions.push_back(dimension);
  }

  TfLiteStatus status = TfLiteInterpreterResizeInputTensor(
      self.interpreter, (int32_t)index, cDimensions.data(), (int32_t)rank);
  if (status == kTfLiteOk) {
    return YES;
  }

  NSString *errorDescription = [NSString
      stringWithFormat:@"Failed to resize input tensor at index (%lu).", (unsigned long)index];
  [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToResizeInputTensor
                                 description:errorDescription
                                       error:error];
  return NO;
}
253
/** Asks the C interpreter to allocate its tensors; returns `NO` and populates `error` on failure. */
- (BOOL)allocateTensorsWithError:(NSError **)error {
  TfLiteStatus status = TfLiteInterpreterAllocateTensors(self.interpreter);
  if (status == kTfLiteOk) {
    return YES;
  }

  [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToAllocateTensors
                                 description:@"Failed to allocate memory for tensors."
                                       error:error];
  return NO;
}
263
/**
 * Creates a signature runner for `key`. Returns `nil` and populates `error` (in the signature
 * runner error domain) when `key` is not one of this interpreter's signature keys.
 */
- (nullable TFLSignatureRunner *)signatureRunnerWithKey:(NSString *)key error:(NSError **)error {
  BOOL keyIsKnown = [self.signatureKeys containsObject:key];
  if (keyIsKnown) {
    return [[TFLSignatureRunner alloc] initWithInterpreter:self signatureKey:key error:error];
  }

  NSString *errorDescription = [NSString
      stringWithFormat:@"Failed to create a signature runner. Signature with key (%@) not found.",
                       key];
  [TFLErrorUtil setError:error
              withDomain:TFLSignatureRunnerErrorDomain
                    code:TFLSignatureRunnerErrorCodeFailedToCreateSignatureRunner
             description:errorDescription];
  return nil;
}
277
278#pragma mark - TFLTensorDataAccessor
279
/**
 * Copies `data` into the given input tensor.
 *
 * Fails when `inputTensor` is an output tensor, when the backing C tensor cannot be found, when
 * `data.length` differs from the tensor's byte size, or when the C API copy fails.
 *
 * @param data Bytes to copy; the length must equal the tensor's byte size exactly.
 * @param inputTensor Destination tensor; must be an input tensor.
 * @param error Populated on failure.
 * @return `YES` when the data was copied, `NO` otherwise.
 */
- (BOOL)copyData:(NSData *)data toInputTensor:(TFLTensor *)inputTensor error:(NSError **)error {
  if (inputTensor.type == TFLTensorTypeOutput) {
    [TFLErrorUtil
        saveInterpreterErrorWithCode:TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed
                         description:@"Cannot copy data into an output tensor."
                               error:error];
    return NO;
  }

  // Read the index from the tensor. This method has no `index` parameter, so a bare `index`
  // would silently resolve to the C library function of that name.
  const NSUInteger tensorIndex = inputTensor.index;
  const TfLiteTensor *cTensor = [self cTensorOfType:TFLTensorTypeInput
                                            atIndex:tensorIndex
                                              error:error];
  if (cTensor == nullptr) {
    return NO;
  }

  NSUInteger byteSize = (NSUInteger)TfLiteTensorByteSize(cTensor);
  if (data.length != byteSize) {
    NSString *errorDescription = [NSString
        stringWithFormat:@"Input tensor at index (%lu) expects data size (%lu), but got (%lu).",
                         (unsigned long)tensorIndex, (unsigned long)byteSize,
                         (unsigned long)data.length];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidInputByteSize
                                   description:errorDescription
                                         error:error];
    return NO;
  }

  // The lookup helper returns a const pointer, but the copy API needs a mutable one.
  TfLiteTensor *mutableCTensor = const_cast<TfLiteTensor *>(cTensor);
  if (TfLiteTensorCopyFromBuffer(mutableCTensor, data.bytes, data.length) != kTfLiteOk) {
    NSString *errorDescription =
        [NSString stringWithFormat:@"Failed to copy data into input tensor at index (%lu).",
                                   (unsigned long)tensorIndex];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCopyDataToInputTensor
                                   description:errorDescription
                                         error:error];
    return NO;
  }

  return YES;
}
318
/**
 * Returns a copy of the bytes currently held by `tensor`, or `nil` (with `error` populated) when
 * the C tensor cannot be found, has no data pointer, or has a byte size of zero.
 */
- (nullable NSData *)dataFromTensor:(TFLTensor *)tensor error:(NSError **)error {
  const TfLiteTensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error];
  if (cTensor == nullptr) {
    return nil;
  }

  void *buffer = TfLiteTensorData(cTensor);
  NSUInteger length = (NSUInteger)TfLiteTensorByteSize(cTensor);
  if (buffer != nullptr && length != 0) {
    return [NSData dataWithBytes:buffer length:length];
  }

  NSString *errorDescription =
      [NSString stringWithFormat:@"Failed to get data from %@ tensor at index (%lu).",
                                 [TFLTensor stringForTensorType:tensor.type],
                                 (unsigned long)tensor.index];
  [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetDataFromTensor
                                 description:errorDescription
                                       error:error];
  return nil;
}
340
/**
 * Returns the dimensions of `tensor` as boxed unsigned integers, or `nil` (with `error`
 * populated) when the C tensor cannot be found, its rank is not positive, or any dimension is
 * not positive.
 */
- (nullable NSArray<NSNumber *> *)shapeOfTensor:(TFLTensor *)tensor error:(NSError **)error {
  const TfLiteTensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error];
  if (cTensor == nullptr) {
    return nil;
  }

  NSString *tensorType = [TFLTensor stringForTensorType:tensor.type];
  const int32_t rank = TfLiteTensorNumDims(cTensor);
  if (rank <= 0) {
    NSString *errorDescription =
        [NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid rank (%d).", tensorType,
                                   (unsigned long)tensor.index, rank];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
                                   description:errorDescription
                                         error:error];
    return nil;
  }

  NSMutableArray<NSNumber *> *dimensions = [NSMutableArray arrayWithCapacity:rank];
  for (int32_t axis = 0; axis < rank; ++axis) {
    const int32_t extent = TfLiteTensorDim(cTensor, axis);
    if (extent > 0) {
      [dimensions addObject:@((NSUInteger)extent)];
      continue;
    }

    NSString *errorDescription =
        [NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid %d-th dimension (%d).",
                                   tensorType, (unsigned long)tensor.index, axis, extent];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
                                   description:errorDescription
                                         error:error];
    return nil;
  }

  return dimensions;
}
376
377#pragma mark - Private
378
/**
 * Looks up the C tensor of the given kind at `index`. Returns a null pointer and populates
 * `error` when the C API yields no tensor.
 */
- (const TfLiteTensor *)cTensorOfType:(TFLTensorType)type
                              atIndex:(NSUInteger)index
                                error:(NSError **)error {
  const TfLiteTensor *cTensor = nullptr;
  if (type == TFLTensorTypeInput) {
    cTensor = TfLiteInterpreterGetInputTensor(self.interpreter, (int32_t)index);
  } else if (type == TFLTensorTypeOutput) {
    cTensor = TfLiteInterpreterGetOutputTensor(self.interpreter, (int32_t)index);
  }

  if (cTensor != nullptr) {
    return cTensor;
  }

  NSString *errorDescription =
      [NSString stringWithFormat:@"Failed to get %@ tensor at index (%lu).",
                                 [TFLTensor stringForTensorType:type], (unsigned long)index];
  [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetTensor
                                 description:errorDescription
                                       error:error];
  return cTensor;
}
405
/**
 * Builds a TFLTensor describing the C tensor of the given kind at `index`: its name, data type
 * and quantization parameters. Returns `nil` when the C tensor cannot be found or has no name.
 */
- (nullable TFLTensor *)tensorOfType:(TFLTensorType)type
                             atIndex:(NSUInteger)index
                               error:(NSError **)error {
  const TfLiteTensor *cTensor = [self cTensorOfType:type atIndex:index error:error];
  if (cTensor == nullptr) {
    return nil;
  }

  NSString *tensorName = TFLTensorNameFromCTensor(cTensor);
  if (tensorName == nil) {
    NSString *errorDescription =
        [NSString stringWithFormat:@"Failed to get name of %@ tensor at index (%lu).",
                                   [TFLTensor stringForTensorType:type], (unsigned long)index];
    [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
                                   description:errorDescription
                                         error:error];
    return nil;
  }

  return [[TFLTensor alloc] initWithInterpreter:self
                                           type:type
                                          index:index
                                           name:tensorName
                                       dataType:TFLTensorDataTypeFromCTensor(cTensor)
                         quantizationParameters:TFLQuantizationParamsFromCTensor(cTensor)];
}
436
/**
 * Checks that `index` is strictly below `totalTensorCount`.
 *
 * @param index Tensor index to validate.
 * @param totalTensorCount Number of tensors available; valid indices are 0..count-1.
 * @param error Populated with an invalid-index error when the check fails.
 * @return `YES` when the index is in range, `NO` otherwise.
 */
- (BOOL)isValidTensorIndex:(NSUInteger)index
                belowLimit:(NSUInteger)totalTensorCount
                     error:(NSError **)error {
  if (index < totalTensorCount) {
    return YES;
  }

  NSString *errorDescription = nil;
  if (totalTensorCount == 0) {
    // Avoid computing (0 - 1) on an unsigned value, which would report a wrapped-around maximum.
    errorDescription =
        [NSString stringWithFormat:@"Invalid tensor index (%lu). No tensors are available.",
                                   (unsigned long)index];
  } else {
    // Both arguments are cast so that they match %lu regardless of the width of NSUInteger.
    errorDescription =
        [NSString stringWithFormat:@"Invalid tensor index (%lu) exceeds max (%lu).",
                                   (unsigned long)index, (unsigned long)(totalTensorCount - 1)];
  }
  [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensorIndex
                                 description:errorDescription
                                       error:error];
  return NO;
}
452
453@end
454
455NS_ASSUME_NONNULL_END
456