/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CORE_LOAD_MODEL_H
#define MINDSPORE_CORE_LOAD_MODEL_H

#include <map>
#include <vector>
#include <string>
#include <memory>

#include "ir/func_graph.h"

namespace mindspore {
// Layout information recorded for a (possibly sharded) parameter: device
// arrangement, tensor map, slice shape, optimizer-shard group and
// pipeline-sharing metadata.
class Layout {
 public:
  Layout() = default;

  const std::vector<int64_t> &get_device_arrangement() const { return device_arrangement_; }
  void set_device_arrangement(const std::vector<int64_t> &device_arrangement) {
    device_arrangement_ = device_arrangement;
  }
  const std::vector<int64_t> &get_tensor_map() const { return tensor_map_; }
  void set_tensor_map(const std::vector<int64_t> &tensor_map) { tensor_map_ = tensor_map; }
  const std::vector<int64_t> &get_slice_shape() const { return slice_shape_; }
  void set_slice_shape(const std::vector<int64_t> &slice_shape) { slice_shape_ = slice_shape; }
  int64_t get_field_size() const { return field_size_; }
  void set_field_size(int64_t field_size) { field_size_ = field_size; }
  bool get_uniform_split() const { return uniform_split_; }
  void set_uniform_split(bool uniform_split) { uniform_split_ = uniform_split; }
  const std::string &get_opt_shard_group() const { return opt_shard_group_; }
  void set_opt_shard_group(const std::string &opt_shard_group) { opt_shard_group_ = opt_shard_group; }
  void set_pipeline_shared(bool pipeline_shared) { pipeline_shared_ = pipeline_shared; }
  bool pipeline_shared() const { return pipeline_shared_; }
  void set_is_send(bool is_send) { is_send_ = is_send; }
  bool is_send() const { return is_send_; }
  void set_peer_rank(int64_t peer_rank) { peer_rank_ = peer_rank; }
  int64_t peer_rank() const { return peer_rank_; }
  void set_sr_tag(int64_t sr_tag) { sr_tag_ = sr_tag; }
  int64_t sr_tag() const { return sr_tag_; }

 private:
  std::vector<int64_t> device_arrangement_{};
  std::vector<int64_t> tensor_map_{};
  std::vector<int64_t> slice_shape_{};
  int64_t field_size_ = 0;
  bool uniform_split_ = false;
  std::string opt_shard_group_ = "";
  // pipeline stage shared param info
  bool pipeline_shared_ = false;
  bool is_send_ = false;
  int64_t peer_rank_{0};
  int64_t sr_tag_{0};
};
using LayoutPtr = std::shared_ptr<Layout>;
using LayoutMap = std::map<string, LayoutPtr>;
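// A minimal usage sketch for Layout/LayoutMap, kept as a comment so the header
// stays declaration-only. The parameter name and numeric values below are
// illustrative placeholders, not values taken from any real model.
//
//   auto layout = std::make_shared<Layout>();
//   layout->set_device_arrangement({2, 4});  // 8 devices viewed as a 2 x 4 grid
//   layout->set_tensor_map({1, 0});          // map tensor dims onto device dims
//   layout->set_slice_shape({16, 32});       // shape of the slice held locally
//   LayoutMap layout_map;
//   layout_map["fc1.weight"] = layout;       // keyed by parameter name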
// Loads MindIR models (optionally encrypted) into FuncGraph objects, from a
// file path or an in-memory buffer, and exposes the parallel layout
// information and preprocess steps stored alongside the graph.
class MS_CORE_API MindIRLoader {
 public:
  MindIRLoader() = default;
  MindIRLoader(bool is_lite, const unsigned char *dec_key, const size_t key_len, const std::string &dec_mode,
               bool inc_load)
      : is_lite_(is_lite), dec_key_(dec_key), key_len_(key_len), dec_mode_(dec_mode), inc_load_(inc_load) {}
  ~MindIRLoader() = default;

  void set_has_parallel_info(bool has_parallel_info) { has_parallel_info_ = has_parallel_info; }
  void set_weights_value_map(const std::map<string, ValuePtr> &weights_value_map) {
    weights_value_map_ = weights_value_map;
  }
  const LayoutMap &layout_map() const { return layout_map_; }
  FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
  FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path);
  FuncGraphPtr LoadMindIR(const std::string &file_name,
                          mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr);
  bool LoadMindIR(const std::string &file_name, const std::vector<FuncGraphPtr> &graphs,
                  mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr);
  bool LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path, FuncGraphPtr *func_graph,
                  std::string *user_info_string);
  std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> &file_names);
  std::vector<std::string> LoadPreprocess(const std::string &file_name);
  bool is_lite() const { return is_lite_; }
  bool inc_load() const { return inc_load_; }
  size_t key_len() const { return key_len_; }
  const unsigned char *dec_key() const { return dec_key_; }
  const std::string &dec_mode() const { return dec_mode_; }

 private:
  bool is_lite_ = false;
  const unsigned char *dec_key_ = nullptr;
  size_t key_len_ = 0;
  std::string dec_mode_ = std::string("AES-GCM");
  bool inc_load_ = false;
  std::map<string, ValuePtr> weights_value_map_;
  bool has_parallel_info_ = false;
  LayoutMap layout_map_;
};
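// A minimal usage sketch (comment only): loading a plaintext MindIR file into a
// FuncGraph. "net.mindir" is a placeholder path; for an encrypted file, pass a
// decryption key and mode through the five-argument constructor instead.
//
//   MindIRLoader loader;
//   FuncGraphPtr func_graph = loader.LoadMindIR("net.mindir");
//   if (func_graph == nullptr) {
//     // handle load failure
//   }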
MS_CORE_API FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
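// Sketch: ConvertStreamToFuncGraph builds a FuncGraph straight from an
// in-memory MindIR stream; `buf` and `buf_size` below stand for a buffer the
// caller has already filled.
//
//   FuncGraphPtr graph = ConvertStreamToFuncGraph(buf, buf_size);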
}  // namespace mindspore
#endif  // MINDSPORE_CORE_LOAD_MODEL_H