• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_TYPE_MANAGER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_TYPE_MANAGER_H
19 
20 #include <map>
21 #include <set>
22 #include <string>
23 #include <vector>
24 
25 #include "Manager.h"
26 
27 #ifndef NN_COMPATIBILITY_LIBRARY_BUILD
28 #include "AppInfoFetcher.h"
29 #endif  // NN_COMPATIBILITY_LIBRARY_BUILD
30 
31 namespace android {
32 namespace nn {
33 
34 // Manages runtime operand and operation type information.
35 //
36 // This class gathers information about extension types from all devices
37 // and provides a unified way to access information about any known type.
38 class TypeManager {
39    public:
get()40     static TypeManager* get() {
41         static TypeManager manager;
42         return &manager;
43     }
44 
45     // Creates an operand/operation type corresponding to a given extension
46     // name and type within extension.
47     //
48     // Returns false if the extension is unknown.
49     bool getExtensionType(const char* extensionName, uint16_t typeWithinExtension, int32_t* type);
50 
51     // Looks up information about the extension corresponding to the given prefix
52     //
53     // Returns false if no extension corresponds to the given prefix.
54     bool getExtensionInfo(uint16_t prefix, const Extension** extension) const;
55 
56     // Looks up information about an extension operand type
57     //
58     // Returns false if the extension or type is unknown.
59     bool getExtensionOperandTypeInfo(OperandType type,
60                                      const Extension::OperandTypeInformation** info) const;
61 
62     // Returns true if an operand type is a tensor type.
63     //
64     // Aborts if the type is an unknown extension type.
65     bool isTensorType(OperandType type) const;
66 
67     // Returns the amount of space needed to store a value of the dimensions and
68     // type of this operand. For a tensor with unspecified rank or at least one
69     // unspecified dimension, returns zero.
70     //
71     // Aborts if the type is an unknown extension type.
72     // Aborts if the size would overflow the return type.
getSizeOfData(const Operand & operand)73     uint32_t getSizeOfData(const Operand& operand) const {
74         return getSizeOfData(operand.type, operand.dimensions);
75     }
76 
77     // Returns the amount of space needed to store a value of the specified
78     // dimensions and type. For a tensor with unspecified rank or at least one
79     // unspecified dimension, returns zero.
80     //
81     // Aborts if the type is an unknown extension type.
82     uint32_t getSizeOfData(OperandType type, const std::vector<uint32_t>& dimensions) const;
83 
84     // Returns true if the amount of space needed to store a value of the specified
85     // dimensions and element size overflows the uint32_t type.
86     //
87     // See also TypeManager::sizeOfDataOverflowsUInt32().
88     bool sizeOfDataOverflowsUInt32(OperandType type, const std::vector<uint32_t>& dimensions) const;
89 
90     // Returns true if extensions usage is allowed in current process.
areExtensionsAllowed()91     bool areExtensionsAllowed() const { return mExtensionsAllowed; }
92 
93     // This method is intended for use only by internal unit tests.
94     //
95     // Registers an extension.
96     //
97     // Returns true if the registration was successful.
forTest_registerExtension(const Extension & extension)98     bool forTest_registerExtension(const Extension& extension) {
99         return registerExtension(extension, "INTERNAL TEST");
100     }
101 
102     // This method is intended for use only by internal unit tests.
103     //
104     // Resets the internal state.
105     //
106     // After calling forTest_registerExtension() any number of times, call
107     // forTest_reset() to return to the state as if forTest_registerExtension()
108     // had never been called. Note that forTest_reset() resets all internal
109     // state (including assigned prefixes) and re-discovers extensions from
110     // available devices.
forTest_reset()111     void forTest_reset() { *this = TypeManager(); }
112 
113 #ifndef NN_COMPATIBILITY_LIBRARY_BUILD
114     // Check if NNAPI Vendor extensions are usable in the process with the given app
115     // and supplemental infomation.
116     //
117     // useOnProductImageEnabled - whether apps/binaries preinstalled on /product partition
118     // can be enabled for extensions use.
119     // allowlist - list of apps/binaries which are allowed to use extensions.
120     static bool isExtensionsUseAllowed(const AppInfoFetcher::AppInfo& appPackageInfo,
121                                        bool useOnProductImageEnabled,
122                                        const std::vector<std::string>& allowlist);
123 #endif  // NN_COMPATIBILITY_LIBRARY_BUILD
124 
125    private:
126     TypeManager();
127     void findAvailableExtensions();
128     bool registerExtension(Extension extension, const std::string& deviceName);
129 
130     // Returns the numeric "prefix" value corresponding to an extension.
131     //
132     // Returns false when assigning a new prefix would overflow uint16_t.
133     bool getExtensionPrefix(const std::string& extensionName, uint16_t* prefix);
134 
135     const DeviceManager* mDeviceManager = DeviceManager::get();
136 
137     // Contains all registered extensions.
138     std::map<std::string, Extension> mExtensionNameToExtension;
139 
140     // Contains the name of the first discovered device that supports an
141     // extension. Used for error reporting.
142     std::map<std::string, std::string> mExtensionNameToFirstDevice;
143 
144     // When multiple devices report conflicting information about an extension,
145     // the extension is disabled.
146     std::set<std::string> mDisabledExtensions;
147 
148     // The fields below are used to support efficient extension name to
149     // prefix mapping. New prefixes are created by getExtensionPrefix.
150     std::map<std::string, uint16_t> mExtensionNameToPrefix;
151     // Entries of mPrefixToExtension point into mExtensionNameToExtension.
152     // prefix=0 corresponds to no extension and should never be looked up.
153     std::vector<Extension*> mPrefixToExtension = {nullptr};
154 
155     // True if Extensions can be used in current process.
156     bool mExtensionsAllowed = false;
157 };
158 
159 }  // namespace nn
160 }  // namespace android
161 
162 #endif  // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_TYPE_MANAGER_H
163