• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef NEURAL_NETWORK_RUNTIME_NNCOMPILER_H
17 #define NEURAL_NETWORK_RUNTIME_NNCOMPILER_H
18 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "compiler.h"

#include "mindir.h"
#include "device.h"
#include "inner_model.h"
#include "prepared_model.h"
#include "nnexecutor.h"
26 
27 namespace OHOS {
28 namespace NeuralNetworkRuntime {
29 
// Compiler implementation for the Neural Network Runtime.
//
// An NNCompiler is bound to a single backend Device at construction time and
// drives compilation of a model into a PreparedModel, which can then be wrapped
// in an NNExecutor via CreateExecutor(). It also supports persisting the
// compiled artifact to a file cache or a caller-supplied memory buffer and
// restoring from either.
class NNCompiler : public Compiler {
public:
    // A compiler is meaningless without a device/backend binding, so default
    // construction is disallowed.
    NNCompiler() = delete;

    // Constructs a compiler bound to `device`/`backendID` with no model yet;
    // a model is presumably supplied later (e.g. restored from cache) —
    // TODO confirm against the .cpp.
    NNCompiler(std::shared_ptr<Device> device, size_t backendID);

    // Constructs a compiler for an existing model. NOTE(review): `model` is a
    // type-erased pointer; ownership and concrete type (likely InnerModel*,
    // given the m_innerModel member below) are not visible in this header —
    // verify in the implementation file.
    NNCompiler(const void* model, std::shared_ptr<Device> device, size_t backendID);

    ~NNCompiler() override;

    // Returns the ID of the backend this compiler was constructed with.
    size_t GetBackendID() const override;

    // Compilation options. These configure the build and must be set before
    // Build() for the values to take effect — TODO confirm ordering contract.
    OH_NN_ReturnCode SetCacheDir(const std::string& cacheModelPath, uint32_t version) override;
    OH_NN_ReturnCode SetPerformance(OH_NN_PerformanceMode performance) override;
    OH_NN_ReturnCode SetPriority(OH_NN_Priority priority) override;
    OH_NN_ReturnCode SetEnableFp16(bool isFp16) override;

    // True once Build() has completed (tracked by m_isBuild).
    bool IsBuild() const override;

    // Compiles the model on the bound device, producing m_preparedModel.
    OH_NN_ReturnCode Build() override;

    // Persistence of the compiled model: file-cache variants use the path and
    // version set via SetCacheDir(); buffer variants serialize into / restore
    // from caller-provided memory.
    OH_NN_ReturnCode SaveToCacheFile() const override;
    OH_NN_ReturnCode RestoreFromCacheFile() override;
    OH_NN_ReturnCode SaveToCacheBuffer(const void* buffer, size_t length, size_t* modelSize) const override;
    OH_NN_ReturnCode RestoreFromCacheBuffer(const void* buffer, size_t length) override;

    // Backend-specific configuration, passed through as opaque key -> byte-blob
    // pairs (configs) or type-erased option objects (options).
    OH_NN_ReturnCode SetExtensionConfig(const std::unordered_map<std::string, std::vector<char>>& configs) override;
    OH_NN_ReturnCode SetOptions(const std::vector<std::shared_ptr<void>>& options) override;

    // Creates an executor over the compiled model. NOTE(review): returns a raw
    // pointer — whether the caller owns (and must delete) the NNExecutor is not
    // visible here; confirm against the implementation and call sites.
    NNExecutor* CreateExecutor();

private:
    // Buffer release helpers: one frees locally allocated buffers, the other
    // releases buffers through the device — TODO confirm exact allocation
    // pairing in the .cpp.
    void ReleaseBuffer(std::vector<Buffer>& buffers) const;
    void ReleaseBufferByDevice(std::vector<Buffer>& buffers) const;

    // (De)serialization of input/output tensor descriptions, used by the cache
    // save/restore paths. Note the asymmetric naming: Serialize.../Deserialized...
    OH_NN_ReturnCode SerializeTensorsToBuffer(
        const std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>>& tensorDescs,
        Buffer& buffer) const;
    OH_NN_ReturnCode DeserializedTensorsFromBuffer(
        const Buffer& buffer, std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>>& tensorDescs);

    // Build() internals: the normal (online) compile path, the offline-model
    // path, and validation helpers that report through out-parameters.
    OH_NN_ReturnCode NormalBuild();
    OH_NN_ReturnCode BuildOfflineModel();
    OH_NN_ReturnCode CheckModelParameter() const;
    OH_NN_ReturnCode IsOfflineModel(bool& isOfflineModel) const;
    OH_NN_ReturnCode IsSupportedModel(const std::shared_ptr<mindspore::lite::LiteGraph>& liteGraph,
                                      bool& isSupportedModel) const;

private:
    bool m_isBuild {false};                            // set when Build() succeeds
    bool m_enableFp16 {false};                         // FP16 toggle from SetEnableFp16()
    std::string m_cachePath;                           // cache directory from SetCacheDir()
    uint32_t m_cacheVersion {0};                       // cache version from SetCacheDir()
    std::shared_ptr<Device> m_device {nullptr};        // backend device bound at construction
    size_t m_backendID {0};                            // backend ID bound at construction
    OH_NN_Priority m_priority {OH_NN_PRIORITY_NONE};
    OH_NN_PerformanceMode m_performance {OH_NN_PERFORMANCE_NONE};
    std::shared_ptr<PreparedModel> m_preparedModel {nullptr};  // result of Build()
    Buffer m_quantBuffer {nullptr, 0};                 // quantization data — TODO confirm producer
    std::string m_modelName;
    void* m_metaGraph {nullptr};                       // opaque graph handle — type not visible here
    InnerModel* m_innerModel {nullptr};                // non-owning? TODO confirm ownership in .cpp
    std::shared_ptr<mindspore::lite::LiteGraph> m_liteGraph {nullptr};
    // Input/output tensor descriptions paired with their tensor type, used for
    // executor creation and cache (de)serialization.
    std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>> m_inputTensorDescs;
    std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>> m_outputTensorDescs;
};
91 } // NeuralNetworkRuntime
92 } // OHOS
93 
94 #endif // NEURAL_NETWORK_RUNTIME_NNCOMPILER_H