• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/nnapi/nnapi_handler.h"

#include <cstdio>
#include <cstring>
#include <string>

#include "tensorflow/lite/nnapi/nnapi_implementation.h"
20 
21 namespace tflite {
22 namespace nnapi {
23 
24 // static
25 const char NnApiHandler::kNnapiReferenceDeviceName[] = "nnapi-reference";
26 // static
27 const int NnApiHandler::kNnapiReferenceDevice = 1;
28 // static
29 const int NnApiHandler::kNnapiDevice = 2;
30 
31 char* NnApiHandler::nnapi_device_name_ = nullptr;
32 int NnApiHandler::nnapi_device_feature_level_;
33 
NnApiPassthroughInstance()34 const NnApi* NnApiPassthroughInstance() {
35   static const NnApi orig_nnapi_copy = *NnApiImplementation();
36   return &orig_nnapi_copy;
37 }
38 
39 // static
Instance()40 NnApiHandler* NnApiHandler::Instance() {
41   // Ensuring that the original copy of nnapi is saved before we return
42   // access to NnApiHandler
43   NnApiPassthroughInstance();
44   static NnApiHandler handler{const_cast<NnApi*>(NnApiImplementation())};
45   return &handler;
46 }
47 
Reset()48 void NnApiHandler::Reset() {
49   // Restores global NNAPI to original value
50   *nnapi_ = *NnApiPassthroughInstance();
51 }
52 
SetAndroidSdkVersion(int version,bool set_unsupported_ops_to_null)53 void NnApiHandler::SetAndroidSdkVersion(int version,
54                                         bool set_unsupported_ops_to_null) {
55   nnapi_->android_sdk_version = version;
56   nnapi_->nnapi_runtime_feature_level = version;
57 
58   if (!set_unsupported_ops_to_null) {
59     return;
60   }
61 
62   if (version < 29) {
63     nnapi_->ANeuralNetworks_getDeviceCount = nullptr;
64     nnapi_->ANeuralNetworks_getDevice = nullptr;
65     nnapi_->ANeuralNetworksDevice_getName = nullptr;
66     nnapi_->ANeuralNetworksDevice_getVersion = nullptr;
67     nnapi_->ANeuralNetworksDevice_getFeatureLevel = nullptr;
68     nnapi_->ANeuralNetworksDevice_getType = nullptr;
69     nnapi_->ANeuralNetworksModel_getSupportedOperationsForDevices = nullptr;
70     nnapi_->ANeuralNetworksCompilation_createForDevices = nullptr;
71     nnapi_->ANeuralNetworksCompilation_setCaching = nullptr;
72     nnapi_->ANeuralNetworksExecution_compute = nullptr;
73     nnapi_->ANeuralNetworksExecution_getOutputOperandRank = nullptr;
74     nnapi_->ANeuralNetworksExecution_getOutputOperandDimensions = nullptr;
75     nnapi_->ANeuralNetworksBurst_create = nullptr;
76     nnapi_->ANeuralNetworksBurst_free = nullptr;
77     nnapi_->ANeuralNetworksExecution_burstCompute = nullptr;
78     nnapi_->ANeuralNetworksMemory_createFromAHardwareBuffer = nullptr;
79     nnapi_->ANeuralNetworksExecution_setMeasureTiming = nullptr;
80     nnapi_->ANeuralNetworksExecution_getDuration = nullptr;
81     nnapi_->ANeuralNetworksDevice_getExtensionSupport = nullptr;
82     nnapi_->ANeuralNetworksModel_getExtensionOperandType = nullptr;
83     nnapi_->ANeuralNetworksModel_getExtensionOperationType = nullptr;
84     nnapi_->ANeuralNetworksModel_setOperandExtensionData = nullptr;
85   }
86   if (version < 28) {
87     nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16 = nullptr;
88   }
89 }
90 
SetDeviceName(const std::string & name)91 void NnApiHandler::SetDeviceName(const std::string& name) {
92   delete[] nnapi_device_name_;
93   nnapi_device_name_ = new char[name.size() + 1];
94   std::strcpy(nnapi_device_name_, name.c_str());  // NOLINT
95 }
96 
GetDeviceNameReturnsName(const std::string & name)97 void NnApiHandler::GetDeviceNameReturnsName(const std::string& name) {
98   NnApiHandler::SetDeviceName(name);
99   GetDeviceNameReturns<0>();
100 }
101 
SetNnapiSupportedDevice(const std::string & name,int feature_level)102 void NnApiHandler::SetNnapiSupportedDevice(const std::string& name,
103                                            int feature_level) {
104   NnApiHandler::SetDeviceName(name);
105   nnapi_device_feature_level_ = feature_level;
106 
107   GetDeviceCountReturnsCount<2>();
108   nnapi_->ANeuralNetworks_getDevice =
109       [](uint32_t devIndex, ANeuralNetworksDevice** device) -> int {
110     if (devIndex > 1) {
111       return ANEURALNETWORKS_BAD_DATA;
112     }
113 
114     if (devIndex == 1) {
115       *device =
116           reinterpret_cast<ANeuralNetworksDevice*>(NnApiHandler::kNnapiDevice);
117     } else {
118       *device = reinterpret_cast<ANeuralNetworksDevice*>(
119           NnApiHandler::kNnapiReferenceDevice);
120     }
121     return ANEURALNETWORKS_NO_ERROR;
122   };
123   nnapi_->ANeuralNetworksDevice_getName =
124       [](const ANeuralNetworksDevice* device, const char** name) -> int {
125     if (device ==
126         reinterpret_cast<ANeuralNetworksDevice*>(NnApiHandler::kNnapiDevice)) {
127       *name = NnApiHandler::nnapi_device_name_;
128       return ANEURALNETWORKS_NO_ERROR;
129     }
130     if (device == reinterpret_cast<ANeuralNetworksDevice*>(
131                       NnApiHandler::kNnapiReferenceDevice)) {
132       *name = NnApiHandler::kNnapiReferenceDeviceName;
133       return ANEURALNETWORKS_NO_ERROR;
134     }
135 
136     return ANEURALNETWORKS_BAD_DATA;
137   };
138   nnapi_->ANeuralNetworksDevice_getFeatureLevel =
139       [](const ANeuralNetworksDevice* device, int64_t* featureLevel) -> int {
140     if (device ==
141         reinterpret_cast<ANeuralNetworksDevice*>(NnApiHandler::kNnapiDevice)) {
142       *featureLevel = NnApiHandler::nnapi_device_feature_level_;
143       return ANEURALNETWORKS_NO_ERROR;
144     }
145     if (device == reinterpret_cast<ANeuralNetworksDevice*>(
146                       NnApiHandler::kNnapiReferenceDevice)) {
147       *featureLevel = 1000;
148       return ANEURALNETWORKS_NO_ERROR;
149     }
150 
151     return ANEURALNETWORKS_BAD_DATA;
152   };
153 }
154 
155 }  // namespace nnapi
156 }  // namespace tflite
157