1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_PACK_WEIGHT_H_ 18 #define MINDSPORE_LITE_SRC_RUNTIME_PACK_WEIGHT_H_ 19 #include <map> 20 #include <string> 21 #include <algorithm> 22 #include <utility> 23 #include <vector> 24 #include <set> 25 #include <mutex> 26 #include <unordered_map> 27 #include <memory> 28 #include "src/tensor.h" 29 #include "src/litert/lite_session.h" 30 namespace mindspore::lite { 31 struct ModelConstWeight { 32 // origin tensor data <-> packed tensor data 33 std::map<const void *, void *> origin_and_packed_pair; 34 std::shared_ptr<Allocator> allocator = nullptr; 35 int numa_id = -1; 36 std::unordered_map<int, void *> tensors_data; 37 std::set<void *> fp16_fp32_data; 38 bool copy_buf; 39 }; 40 41 class PackWeight { 42 public: 43 PackWeight() = default; 44 ~PackWeight(); 45 STATUS InitPackWeight(const void *model_buf, size_t model_size, std::string id, int numa_id, 46 bool need_copy_buf = true); 47 char *GetSharedModelBuf(std::string id, int numa_id); 48 STATUS StoreOriginTensorData(const void *model_buf, const void *origin_tensor_data); 49 void *GetPackData(const void *tensor_data, const size_t size, bool *is_packed); 50 STATUS ReplaceOriginTensorData(const void *model_buf, std::vector<Tensor *> *tensors, int tensor_index); 51 void *ReplaceFp16Data(void *origin_fp16_data, size_t size); 52 void FreePackWeight(std::string id, bool free_all = false); 53 54 private: 55 void FreePackedWeight(ModelConstWeight *weight); 56 void FreeTensorData(ModelConstWeight *weight); 57 void FreeFp16ToFp32Data(ModelConstWeight *weight); 58 59 std::mutex mtx_weight_; 60 std::unordered_map<void *, void *> fp16_fp32_data_pair_; 61 // runner_id/model_id : { numa_id : ModelConstWeight } 62 std::unordered_map<std::string, std::unordered_map<int, ModelConstWeight *>> model_weights_; 63 // runner_id/model_id : { numa_id : shared model buf address } 64 std::unordered_map<std::string, std::unordered_map<int, void *>> shared_bufs_; 65 }; 66 } // namespace mindspore::lite 67 #endif // MINDSPORE_LITE_SRC_RUNTIME_PACK_WEIGHT_H_ 68