1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_ 18 19 #include <deque> 20 #include <memory> 21 #include <queue> 22 #include <string> 23 #include <vector> 24 #include <unordered_map> 25 #include <unordered_set> 26 27 #include "minddata/dataset/core/data_type.h" 28 #include "minddata/dataset/core/tensor.h" 29 #include "minddata/dataset/engine/gnn/edge.h" 30 #include "minddata/dataset/engine/gnn/feature.h" 31 #include "minddata/dataset/engine/gnn/graph_feature_parser.h" 32 #if !defined(_WIN32) && !defined(_WIN64) 33 #include "minddata/dataset/engine/gnn/graph_shared_memory.h" 34 #endif 35 #include "minddata/dataset/engine/gnn/node.h" 36 #include "minddata/dataset/util/status.h" 37 #include "minddata/mindrecord/include/shard_reader.h" 38 #include "minddata/dataset/engine/gnn/graph_data_impl.h" 39 namespace mindspore { 40 namespace dataset { 41 namespace gnn { 42 43 using mindrecord::ShardReader; 44 using NodeIdMap = std::unordered_map<NodeIdType, std::shared_ptr<Node>>; 45 using EdgeIdMap = std::unordered_map<EdgeIdType, std::shared_ptr<Edge>>; 46 using NodeTypeMap = std::unordered_map<NodeType, std::vector<NodeIdType>>; 47 using EdgeTypeMap = std::unordered_map<EdgeType, std::vector<EdgeIdType>>; 48 using NodeFeatureMap = std::unordered_map<NodeType, std::unordered_set<FeatureType>>; 49 using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureType>>; 50 using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; 51 using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; 52 53 // this class interfaces with the underlying storage format (mindrecord) 54 // it returns raw nodes and edges via GetNodesAndEdges 55 // it is then the responsibility of graph to construct itself based on the nodes and edges 56 // if needed, this class could become a base where each derived class handles a specific storage format 57 class GraphLoader { 58 public: 59 GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false); 60 61 ~GraphLoader() = default; 62 // Init mindrecord and load everything into memory multi-threaded 63 // @return Status - the status code 64 Status InitAndLoad(); 65 66 // this function will query mindrecord and construct all nodes and edges 67 // nodes and edges are added to map without any connection. That's because there nodes and edges are read in 68 // random order. src_node and dst_node in Edge are node_id only with -1 as type. 69 // features attached to each node and edge are expected to be filled correctly 70 Status GetNodesAndEdges(); 71 72 private: 73 // 74 // worker thread that reads mindrecord file 75 // @param int32_t worker_id - id of each worker 76 // @return Status - the status code 77 Status WorkerEntry(int32_t worker_id); 78 79 // Load a node based on 1 row of mindrecord, returns a shared_ptr<Node> 80 // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord 81 // @param mindrecord::json &jsn - contains raw data 82 // @param std::shared_ptr<Node> *node - return value 83 // @param NodeFeatureMap *feature_map - 84 // @param DefaultNodeFeatureMap *default_feature - 85 // @return Status - the status code 86 Status LoadNode(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Node> *node, 87 NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature); 88 89 // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord 90 // @param mindrecord::json &jsn - contains raw data 91 // @param std::shared_ptr<Edge> *edge - return value, the edge ptr, edge is not yet connected 92 // @param FeatureMap *feature_map 93 // @param DefaultEdgeFeatureMap *default_feature - 94 // @return Status - the status code 95 Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge, 96 EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); 97 98 // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 99 void MergeFeatureMaps(); 100 101 GraphDataImpl *graph_impl_; 102 std::string mr_path_; 103 const int32_t num_workers_; 104 std::atomic_int row_id_; 105 std::unique_ptr<ShardReader> shard_reader_; 106 std::unique_ptr<GraphFeatureParser> graph_feature_parser_; 107 std::vector<std::deque<std::shared_ptr<Node>>> n_deques_; 108 std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_; 109 std::vector<NodeFeatureMap> n_feature_maps_; 110 std::vector<EdgeFeatureMap> e_feature_maps_; 111 std::vector<DefaultNodeFeatureMap> default_node_feature_maps_; 112 std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_; 113 const std::vector<std::string> required_key_; 114 std::unordered_map<std::string, bool> optional_key_; 115 }; 116 } // namespace gnn 117 } // namespace dataset 118 } // namespace mindspore 119 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_ 120