• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/nnapi/nnapi_handler.h"

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <string>

#include "tensorflow/lite/nnapi/nnapi_implementation.h"
20 
21 namespace tflite {
22 namespace nnapi {
23 
24 // static
25 const char NnApiHandler::kNnapiReferenceDeviceName[] = "nnapi-reference";
26 // static
27 const int NnApiHandler::kNnapiReferenceDevice = 1;
28 // static
29 const int NnApiHandler::kNnapiDevice = 2;
30 
31 char* NnApiHandler::nnapi_device_name_ = nullptr;
32 int NnApiHandler::nnapi_device_feature_level_;
33 
NnApiPassthroughInstance()34 const NnApi* NnApiPassthroughInstance() {
35   static const NnApi orig_nnapi_copy = *NnApiImplementation();
36   return &orig_nnapi_copy;
37 }
38 
39 // static
Instance()40 NnApiHandler* NnApiHandler::Instance() {
41   // Ensuring that the original copy of nnapi is saved before we return
42   // access to NnApiHandler
43   NnApiPassthroughInstance();
44   static NnApiHandler handler{const_cast<NnApi*>(NnApiImplementation())};
45   return &handler;
46 }
47 
Reset()48 void NnApiHandler::Reset() {
49   // Restores global NNAPI to original value
50   *nnapi_ = *NnApiPassthroughInstance();
51 }
52 
SetAndroidSdkVersion(int version,bool set_unsupported_ops_to_null)53 void NnApiHandler::SetAndroidSdkVersion(int version,
54                                         bool set_unsupported_ops_to_null) {
55   nnapi_->android_sdk_version = version;
56 
57   if (!set_unsupported_ops_to_null) {
58     return;
59   }
60 
61   if (version < 29) {
62     nnapi_->ANeuralNetworks_getDeviceCount = nullptr;
63     nnapi_->ANeuralNetworks_getDevice = nullptr;
64     nnapi_->ANeuralNetworksDevice_getName = nullptr;
65     nnapi_->ANeuralNetworksDevice_getVersion = nullptr;
66     nnapi_->ANeuralNetworksDevice_getFeatureLevel = nullptr;
67     nnapi_->ANeuralNetworksDevice_getType = nullptr;
68     nnapi_->ANeuralNetworksModel_getSupportedOperationsForDevices = nullptr;
69     nnapi_->ANeuralNetworksCompilation_createForDevices = nullptr;
70     nnapi_->ANeuralNetworksCompilation_setCaching = nullptr;
71     nnapi_->ANeuralNetworksExecution_compute = nullptr;
72     nnapi_->ANeuralNetworksExecution_getOutputOperandRank = nullptr;
73     nnapi_->ANeuralNetworksExecution_getOutputOperandDimensions = nullptr;
74     nnapi_->ANeuralNetworksBurst_create = nullptr;
75     nnapi_->ANeuralNetworksBurst_free = nullptr;
76     nnapi_->ANeuralNetworksExecution_burstCompute = nullptr;
77     nnapi_->ANeuralNetworksMemory_createFromAHardwareBuffer = nullptr;
78     nnapi_->ANeuralNetworksExecution_setMeasureTiming = nullptr;
79     nnapi_->ANeuralNetworksExecution_getDuration = nullptr;
80     nnapi_->ANeuralNetworksDevice_getExtensionSupport = nullptr;
81     nnapi_->ANeuralNetworksModel_getExtensionOperandType = nullptr;
82     nnapi_->ANeuralNetworksModel_getExtensionOperationType = nullptr;
83     nnapi_->ANeuralNetworksModel_setOperandExtensionData = nullptr;
84   }
85   if (version < 28) {
86     nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16 = nullptr;
87   }
88 }
89 
SetDeviceName(const std::string & name)90 void NnApiHandler::SetDeviceName(const std::string& name) {
91   delete[] nnapi_device_name_;
92   nnapi_device_name_ = new char[name.size() + 1];
93   std::strcpy(nnapi_device_name_, name.c_str());  // NOLINT
94 }
95 
GetDeviceNameReturnsName(const std::string & name)96 void NnApiHandler::GetDeviceNameReturnsName(const std::string& name) {
97   NnApiHandler::SetDeviceName(name);
98   GetDeviceNameReturns<0>();
99 }
100 
SetNnapiSupportedDevice(const std::string & name,int feature_level)101 void NnApiHandler::SetNnapiSupportedDevice(const std::string& name,
102                                            int feature_level) {
103   NnApiHandler::SetDeviceName(name);
104   nnapi_device_feature_level_ = feature_level;
105 
106   GetDeviceCountReturnsCount<2>();
107   nnapi_->ANeuralNetworks_getDevice =
108       [](uint32_t devIndex, ANeuralNetworksDevice** device) -> int {
109     if (devIndex > 1) {
110       return ANEURALNETWORKS_BAD_DATA;
111     }
112 
113     if (devIndex == 1) {
114       *device =
115           reinterpret_cast<ANeuralNetworksDevice*>(NnApiHandler::kNnapiDevice);
116     } else {
117       *device = reinterpret_cast<ANeuralNetworksDevice*>(
118           NnApiHandler::kNnapiReferenceDevice);
119     }
120     return ANEURALNETWORKS_NO_ERROR;
121   };
122   nnapi_->ANeuralNetworksDevice_getName =
123       [](const ANeuralNetworksDevice* device, const char** name) -> int {
124     if (device ==
125         reinterpret_cast<ANeuralNetworksDevice*>(NnApiHandler::kNnapiDevice)) {
126       *name = NnApiHandler::nnapi_device_name_;
127       return ANEURALNETWORKS_NO_ERROR;
128     }
129     if (device == reinterpret_cast<ANeuralNetworksDevice*>(
130                       NnApiHandler::kNnapiReferenceDevice)) {
131       *name = NnApiHandler::kNnapiReferenceDeviceName;
132       return ANEURALNETWORKS_NO_ERROR;
133     }
134 
135     return ANEURALNETWORKS_BAD_DATA;
136   };
137   nnapi_->ANeuralNetworksDevice_getFeatureLevel =
138       [](const ANeuralNetworksDevice* device, int64_t* featureLevel) -> int {
139     if (device ==
140         reinterpret_cast<ANeuralNetworksDevice*>(NnApiHandler::kNnapiDevice)) {
141       *featureLevel = NnApiHandler::nnapi_device_feature_level_;
142       return ANEURALNETWORKS_NO_ERROR;
143     }
144     if (device == reinterpret_cast<ANeuralNetworksDevice*>(
145                       NnApiHandler::kNnapiReferenceDevice)) {
146       *featureLevel = 1000;
147       return ANEURALNETWORKS_NO_ERROR;
148     }
149 
150     return ANEURALNETWORKS_BAD_DATA;
151   };
152 }
153 
154 }  // namespace nnapi
155 }  // namespace tflite
156