1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_NODE_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_NODE_H_ 18 19 #include <memory> 20 #include <unordered_map> 21 #include <vector> 22 23 #include "minddata/dataset/engine/gnn/feature.h" 24 #include "minddata/dataset/util/status.h" 25 26 namespace mindspore { 27 namespace dataset { 28 namespace gnn { 29 using NodeType = int8_t; 30 using NodeIdType = int32_t; 31 using WeightType = float; 32 using EdgeIdType = int32_t; 33 34 constexpr NodeIdType kDefaultNodeId = -1; 35 36 class Edge; 37 38 class Node { 39 public: 40 // Constructor 41 // @param NodeIdType id - node id 42 // @param NodeType type - node type 43 // @param WeightType type - node weight Node(NodeIdType id,NodeType type,WeightType weight)44 Node(NodeIdType id, NodeType type, WeightType weight) : id_(id), type_(type), weight_(weight) {} 45 46 virtual ~Node() = default; 47 48 // @return NodeIdType - Returned node id id()49 NodeIdType id() const { return id_; } 50 51 // @return NodeIdType - Returned node type type()52 NodeType type() const { return type_; } 53 54 // @return WeightType - Returned node weight weight()55 WeightType weight() const { return weight_; } 56 57 // Get the feature of a node 58 // @param FeatureType feature_type - type of feature 59 // @param std::shared_ptr<Feature> *out_feature - Returned feature 60 // @return Status The status code returned 61 virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; 62 63 // Get the all neighbors of a node 64 // @param NodeType neighbor_type - type of neighbor 65 // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id 66 // @return Status The status code returned 67 virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, 68 bool exclude_itself = false) = 0; 69 70 // Get the sampled neighbors of a node 71 // @param NodeType neighbor_type - type of neighbor 72 // @param int32_t samples_num - Number of neighbors to be acquired 73 // @param SamplingStrategy strategy - Sampling strategy 74 // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id 75 // @return Status The status code returned 76 virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, 77 std::vector<NodeIdType> *out_neighbors) = 0; 78 79 // Add neighbor of node 80 // @param std::shared_ptr<Node> node - 81 // @return Status The status code returned 82 virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0; 83 84 // Add adjacent node and relative edge for source node 85 // @param std::shared_ptr<Node> node - the node to be inserted into adjacent table 86 // @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node 87 // @return Status - The status code that indicate the result of function execution 88 virtual Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) = 0; 89 90 // Get relative connecting edge of adjacent node by node id 91 // @param NodeIdType - The id of adjacent node to be processed 92 // @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge 93 // @return Status - The status code that indicate the result of function execution 94 virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) = 0; 95 96 // Update feature of node 97 // @param std::shared_ptr<Feature> feature - 98 // @return Status The status code returned 99 virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; 100 101 protected: 102 NodeIdType id_; 103 NodeType type_; 104 WeightType weight_; 105 }; 106 } // namespace gnn 107 } // namespace dataset 108 } // namespace mindspore 109 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_NODE_H_ 110