• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_META_MODEL_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_META_MODEL_H
19 
20 #include "HalInterfaces.h"
21 
22 #include <android-base/macros.h>
23 #include <functional>
24 #include <map>
25 #include <optional>
26 #include <set>
27 #include <utility>
28 #include <vector>
29 
30 namespace android::nn {
31 
// The MetaModel class encapsulates a Model and provides machinery to create
// from that original Model a "slice" of that Model consisting of:
// - the subset of operations that is compliant with a particular HAL version; and
// - a mechanism for mapping operations from the slice back to operations of the
//   original Model.
// The slice is intended to be passed to IDevice::getSupportedOperations*(),
// with the mapping used to translate the results of that call from the slice's
// operations to the original Model's operations.  The slice has no other
// purpose (for example, it is not guaranteed to have the same topology as a
// subgraph of the original model).
//
// When a getSlice*() method is called, a slice is created and cached, if
// necessary; and then the cached slice is returned.
//
// The meaning of the return value of the getSlice*() methods is explained by
// the following example:
//
//     const MetaModel& metaModel = ...;
//     // getSliceV1_1() and getSliceV1_2() are similar.
//     auto ret = metaModel.getSliceV1_0();
//     if (ret.has_value()) {
//         const hal::V1_0::Model& model = ret->first;  // the slice
//         auto mapper = ret->second;
//         // mapper is a functor that takes an operation index in the
//         // slice and returns the corresponding operation index in the
//         // original Model.  The functor will remain valid for the lifetime
//         // of the MetaModel.
//     } else {
//         // Could not obtain a slice.  For example, perhaps none of the
//         // original model's operations are compliant with V1_0.
//     }
//
63 class MetaModel {
64    public:
65     using Mapper = std::function<uint32_t(uint32_t)>;
66 
67     template <class T_Model>
68     using ReturnedSlice = std::optional<std::pair<T_Model, Mapper>>;
69 
MetaModel(hal::Model model,bool strictSlicing)70     MetaModel(hal::Model model, bool strictSlicing)
71         : mHidlModel(std::move(model)), mStrictSlicing(strictSlicing) {}
72 
getModel()73     const hal::Model& getModel() const { return mHidlModel; }
74 
getSliceV1_0()75     ReturnedSlice<hal::V1_0::Model> getSliceV1_0() const { return getSlice(&mSliceV1_0); }
getSliceV1_1()76     ReturnedSlice<hal::V1_1::Model> getSliceV1_1() const { return getSlice(&mSliceV1_1); }
getSliceV1_2()77     ReturnedSlice<hal::V1_2::Model> getSliceV1_2() const { return getSlice(&mSliceV1_2); }
78 
79     // Disallowing copy constructor and assignment operator is for efficiency,
80     // not for correctness.  The default copy constructor and assignment
81     // operator would work fine.  However, they could be surprisingly expensive
82     // if the mSlice* members get copied: Up to three Model instances and two
83     // std::vector instances could be copied.  We could choose to accept this
84     // expense; or we could write custom copy and assign that do not copy the
85     // mSlice* members but instead set the destination mSlice* members to
86     // SliceState::UNINITIALIZED.
87     //
88     // There are no such issues with move constructor and move assignment.
89     MetaModel(const MetaModel&) = delete;
90     MetaModel& operator=(const MetaModel&) = delete;
91     MetaModel(MetaModel&&) = default;
92     MetaModel& operator=(MetaModel&&) = default;
93 
94    private:
95     hal::Model mHidlModel;
96 
97     // mStrictSlicing controls sanity checking.  If the slicing algorithm
98     // produces an invalid model (because something has gone wrong with the
99     // algorithm or with a utility function it depends on), getSlice*() can
100     // return an std::optional<> for which has_value() returns false, signifying
101     // that no slice is available.  However, if mStrictSlicing is true,
102     // getSlice*() cause a CHECK*() to fail.  This can be used in debugging to
103     // find situations where slicing has failed unexpectedly.
104     bool mStrictSlicing;
105 
106     enum class SliceState { UNINITIALIZED, INVALID, NORMAL };
107     template <class T_SlicedModel>
108     struct Slice {
109         SliceState mState = SliceState::UNINITIALIZED;
110         T_SlicedModel mHidlModel;
111         std::vector<uint32_t> mSlicedOperationIndexToOrigIndex;
112 
113         using Operand = typename decltype(mHidlModel.operands)::value_type;
114         using Operation = typename decltype(mHidlModel.operations)::value_type;
115         using OperationType = decltype(Operation::type);
116     };
117     mutable Slice<hal::V1_0::Model> mSliceV1_0;
118     mutable Slice<hal::V1_1::Model> mSliceV1_1;
119     mutable Slice<hal::V1_2::Model> mSliceV1_2;
120 
121     template <class T_SlicedModel>
122     ReturnedSlice<T_SlicedModel> getSlice(Slice<T_SlicedModel>* slice) const;
123 
124     template <class T_SlicedModel>
125     Slice<T_SlicedModel> makeSlice() const;
126 
127     // Utility class for makeSlice().
128     template <typename T_SlicedOperand>
129     class OrigOperandToSlicedInputOperandIndex;
130 
131     // Utility function for makeSlice(): Walks operations of original
132     // model and populates sliced model accordingly.
133     template <class T_SlicedModel>
134     void processOperations(
135             Slice<T_SlicedModel>* slice,
136             std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex,
137             OrigOperandToSlicedInputOperandIndex<typename Slice<T_SlicedModel>::Operand>*
138                     origOperandToSlicedInputOperandIndex,
139             const std::set<uint32_t>& noncompliantOperations,
140             const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const;
141 };
142 
143 }  // namespace android::nn
144 
145 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_META_MODEL_H
146