• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_NODE_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_NODE_H_
18 
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 namespace gnn {
31 
32 class LocalNode : public Node {
33  public:
34   // Constructor
35   // @param NodeIdType id - node id
36   // @param NodeType type - node type
37   LocalNode(NodeIdType id, NodeType type, WeightType weight);
38 
39   ~LocalNode() = default;
40 
41   // Get the feature of a node
42   // @param FeatureType feature_type - type of feature
43   // @param std::shared_ptr<Feature> *out_feature - Returned feature
44   // @return Status The status code returned
45   Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
46 
47   // Get the all neighbors of a node
48   // @param NodeType neighbor_type - type of neighbor
49   // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
50   // @return Status The status code returned
51   Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
52                          bool exclude_itself = false) override;
53 
54   // Get the sampled neighbors of a node
55   // @param NodeType neighbor_type - type of neighbor
56   // @param int32_t samples_num - Number of neighbors to be acquired
57   // @param SamplingStrategy strategy - Sampling strategy
58   // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
59   // @return Status The status code returned
60   Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
61                              std::vector<NodeIdType> *out_neighbors) override;
62 
63   // Add neighbor of node
64   // @param std::shared_ptr<Node> node -
65   // @return Status The status code returned
66   Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override;
67 
68   // Add adjacent node and relative edge for source node
69   // @param std::shared_ptr<Node> node - the node to be inserted into adjacent table
70   // @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node
71   // @return Status - The status code that indicate the result of function execution
72   Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) override;
73 
74   // Get relative connecting edge of adjacent node by node id
75   // @param NodeIdType - The id of adjacent node to be processed
76   // @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge
77   // @return Status - The status code that indicate the result of function execution
78   Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) override;
79 
80   // Update feature of node
81   // @param std::shared_ptr<Feature> feature -
82   // @return Status The status code returned
83   Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
84 
85  private:
86   Status GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
87                                    std::vector<NodeIdType> *out);
88 
89   Status GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
90                                    const std::vector<WeightType> &weights, int32_t samples_num,
91                                    std::vector<NodeIdType> *out);
92 
93   std::mt19937 rnd_;
94   std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
95   std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_;
96   std::unordered_map<NodeIdType, EdgeIdType> adjacent_nodes_;
97 };
98 }  // namespace gnn
99 }  // namespace dataset
100 }  // namespace mindspore
101 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_NODE_H_
102