• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_
18 
#include <atomic>
#include <deque>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
26 
27 #include "minddata/dataset/core/data_type.h"
28 #include "minddata/dataset/core/tensor.h"
29 #include "minddata/dataset/engine/gnn/edge.h"
30 #include "minddata/dataset/engine/gnn/feature.h"
31 #include "minddata/dataset/engine/gnn/graph_feature_parser.h"
32 #if !defined(_WIN32) && !defined(_WIN64)
33 #include "minddata/dataset/engine/gnn/graph_shared_memory.h"
34 #endif
35 #include "minddata/dataset/engine/gnn/node.h"
36 #include "minddata/dataset/util/status.h"
37 #include "minddata/mindrecord/include/shard_reader.h"
38 #include "minddata/dataset/engine/gnn/graph_data_impl.h"
39 namespace mindspore {
40 namespace dataset {
41 namespace gnn {
42 
43 using mindrecord::ShardReader;
44 using NodeIdMap = std::unordered_map<NodeIdType, std::shared_ptr<Node>>;
45 using EdgeIdMap = std::unordered_map<EdgeIdType, std::shared_ptr<Edge>>;
46 using NodeTypeMap = std::unordered_map<NodeType, std::vector<NodeIdType>>;
47 using EdgeTypeMap = std::unordered_map<EdgeType, std::vector<EdgeIdType>>;
48 using NodeFeatureMap = std::unordered_map<NodeType, std::unordered_set<FeatureType>>;
49 using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureType>>;
50 using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
51 using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
52 
53 // this class interfaces with the underlying storage format (mindrecord)
54 // it returns raw nodes and edges via GetNodesAndEdges
55 // it is then the responsibility of graph to construct itself based on the nodes and edges
56 // if needed, this class could become a base where each derived class handles a specific storage format
57 class GraphLoader {
58  public:
59   GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false);
60 
61   ~GraphLoader() = default;
62   // Init mindrecord and load everything into memory multi-threaded
63   // @return Status - the status code
64   Status InitAndLoad();
65 
66   // this function will query mindrecord and construct all nodes and edges
67   // nodes and edges are added to map without any connection. That's because there nodes and edges are read in
68   // random order. src_node and dst_node in Edge are node_id only with -1 as type.
69   // features attached to each node and edge are expected to be filled correctly
70   Status GetNodesAndEdges();
71 
72  private:
73   //
74   // worker thread that reads mindrecord file
75   // @param int32_t worker_id - id of each worker
76   // @return Status - the status code
77   Status WorkerEntry(int32_t worker_id);
78 
79   // Load a node based on 1 row of mindrecord, returns a shared_ptr<Node>
80   // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
81   // @param mindrecord::json &jsn - contains raw data
82   // @param std::shared_ptr<Node> *node - return value
83   // @param NodeFeatureMap *feature_map -
84   // @param DefaultNodeFeatureMap *default_feature -
85   // @return Status - the status code
86   Status LoadNode(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Node> *node,
87                   NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature);
88 
89   // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
90   // @param mindrecord::json &jsn - contains raw data
91   // @param std::shared_ptr<Edge> *edge - return value, the edge ptr, edge is not yet connected
92   // @param FeatureMap *feature_map
93   // @param DefaultEdgeFeatureMap *default_feature -
94   // @return Status - the status code
95   Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
96                   EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
97 
98   // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
99   void MergeFeatureMaps();
100 
101   GraphDataImpl *graph_impl_;
102   std::string mr_path_;
103   const int32_t num_workers_;
104   std::atomic_int row_id_;
105   std::unique_ptr<ShardReader> shard_reader_;
106   std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
107   std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
108   std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
109   std::vector<NodeFeatureMap> n_feature_maps_;
110   std::vector<EdgeFeatureMap> e_feature_maps_;
111   std::vector<DefaultNodeFeatureMap> default_node_feature_maps_;
112   std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_;
113   const std::vector<std::string> required_key_;
114   std::unordered_map<std::string, bool> optional_key_;
115 };
116 }  // namespace gnn
117 }  // namespace dataset
118 }  // namespace mindspore
119 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_
120