1 /** 2 * Copyright 2020-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_LOAD_MODEL_H 17 #define MINDSPORE_CORE_LOAD_MODEL_H 18 19 #include <map> 20 #include <vector> 21 #include <string> 22 #include <memory> 23 24 #include "ir/func_graph.h" 25 26 namespace mindspore { 27 class Layout { 28 public: 29 Layout() = default; 30 get_device_arrangement()31 const std::vector<int64_t> &get_device_arrangement() const { return device_arrangement_; } set_device_arrangement(const std::vector<int64_t> & device_arrangement)32 void set_device_arrangement(const std::vector<int64_t> &device_arrangement) { 33 device_arrangement_ = device_arrangement; 34 } get_tensor_map()35 const std::vector<int64_t> &get_tensor_map() const { return tensor_map_; } set_tensor_map(const std::vector<int64_t> & tensor_map)36 void set_tensor_map(const std::vector<int64_t> &tensor_map) { tensor_map_ = tensor_map; } get_slice_shape()37 const std::vector<int64_t> &get_slice_shape() const { return slice_shape_; } set_slice_shape(const std::vector<int64_t> & slice_shape)38 void set_slice_shape(const std::vector<int64_t> &slice_shape) { slice_shape_ = slice_shape; } get_field_size()39 int64_t get_field_size() const { return field_size_; } set_field_size(int64_t field_size)40 void set_field_size(int64_t field_size) { field_size_ = field_size; } get_uniform_split()41 bool get_uniform_split() const { return uniform_split_; } set_uniform_split(bool uniform_split)42 void set_uniform_split(bool uniform_split) { uniform_split_ = uniform_split; } get_opt_shard_group()43 const std::string &get_opt_shard_group() const { return opt_shard_group_; } set_opt_shard_group(const std::string & opt_shard_group)44 void set_opt_shard_group(const std::string &opt_shard_group) { opt_shard_group_ = opt_shard_group; } set_pipeline_shared(bool pipeline_shared)45 void set_pipeline_shared(bool pipeline_shared) { pipeline_shared_ = pipeline_shared; } pipeline_shared()46 bool pipeline_shared() const { return pipeline_shared_; } set_is_send(bool is_send)47 void set_is_send(bool is_send) { is_send_ = is_send; } is_send()48 bool is_send() const { return is_send_; } set_peer_rank(int64_t peer_rank)49 void set_peer_rank(int64_t peer_rank) { peer_rank_ = peer_rank; } peer_rank()50 int64_t peer_rank() const { return peer_rank_; } set_sr_tag(int64_t sr_tag)51 void set_sr_tag(int64_t sr_tag) { sr_tag_ = sr_tag; } sr_tag()52 int64_t sr_tag() const { return sr_tag_; } 53 54 private: 55 std::vector<int64_t> device_arrangement_{}; 56 std::vector<int64_t> tensor_map_{}; 57 std::vector<int64_t> slice_shape_{}; 58 int64_t field_size_ = 0; 59 bool uniform_split_ = false; 60 std::string opt_shard_group_ = ""; 61 // pipeline stage shared param info 62 bool pipeline_shared_ = false; 63 bool is_send_ = false; 64 int64_t peer_rank_{0}; 65 int64_t sr_tag_{0}; 66 }; 67 using LayoutPtr = std::shared_ptr<Layout>; 68 using LayoutMap = std::map<string, LayoutPtr>; 69 class MS_CORE_API MindIRLoader { 70 public: 71 MindIRLoader() = default; MindIRLoader(bool is_lite,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode,bool inc_load)72 MindIRLoader(bool is_lite, const unsigned char *dec_key, const size_t key_len, const std::string &dec_mode, 73 bool inc_load) 74 : is_lite_(is_lite), dec_key_(dec_key), key_len_(key_len), dec_mode_(dec_mode), inc_load_(inc_load) {} 75 ~MindIRLoader() = default; 76 set_has_parallel_info(bool has_parallel_info)77 void set_has_parallel_info(bool has_parallel_info) { has_parallel_info_ = has_parallel_info; } set_weights_value_map(const std::map<string,ValuePtr> & weights_value_map)78 void set_weights_value_map(const std::map<string, ValuePtr> &weights_value_map) { 79 weights_value_map_ = weights_value_map; 80 } layout_map()81 const LayoutMap &layout_map() const { return layout_map_; } 82 FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size); 83 FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path); 84 FuncGraphPtr LoadMindIR(const std::string &file_name, 85 mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr); 86 bool LoadMindIR(const std::string &file_name, const std::vector<FuncGraphPtr> &graphs, 87 mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr); 88 bool LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path, FuncGraphPtr *func_graph, 89 std::string *user_info_string); 90 std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> &file_names); 91 std::vector<std::string> LoadPreprocess(const std::string &file_name); is_lite()92 bool is_lite() const { return is_lite_; } inc_load()93 bool inc_load() const { return inc_load_; } key_len()94 size_t key_len() const { return key_len_; } dec_key()95 const unsigned char *dec_key() const { return dec_key_; } dec_mode()96 const std::string &dec_mode() const { return dec_mode_; } 97 98 private: 99 bool is_lite_ = false; 100 const unsigned char *dec_key_ = nullptr; 101 size_t key_len_ = 0; 102 std::string dec_mode_ = std::string("AES-GCM"); 103 bool inc_load_ = false; 104 std::map<string, ValuePtr> weights_value_map_; 105 bool has_parallel_info_ = false; 106 LayoutMap layout_map_; 107 }; 108 MS_CORE_API FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false); 109 } // namespace mindspore 110 #endif // MINDSPORE_CORE_LOAD_MODEL_H 111