• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/gnn/local_node.h"
17 
#include <algorithm>
#include <numeric>
#include <random>
#include <string>
#include <utility>
22 
23 #include "minddata/dataset/engine/gnn/edge.h"
24 #include "minddata/dataset/util/random.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 namespace gnn {
29 
LocalNode(NodeIdType id,NodeType type,WeightType weight)30 LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight)
31     : Node(id, type, weight), rnd_(GetRandomDevice()) {
32   rnd_.seed(GetSeed());
33 }
34 
GetFeatures(FeatureType feature_type,std::shared_ptr<Feature> * out_feature)35 Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
36   auto itr = features_.find(feature_type);
37   if (itr != features_.end()) {
38     *out_feature = itr->second;
39     return Status::OK();
40   } else {
41     std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
42     RETURN_STATUS_UNEXPECTED(err_msg);
43   }
44 }
45 
GetAllNeighbors(NodeType neighbor_type,std::vector<NodeIdType> * out_neighbors,bool exclude_itself)46 Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, bool exclude_itself) {
47   std::vector<NodeIdType> neighbors;
48   auto itr = neighbor_nodes_.find(neighbor_type);
49   if (itr != neighbor_nodes_.end()) {
50     if (exclude_itself) {
51       neighbors.resize(itr->second.first.size());
52       std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(),
53                      [](const std::shared_ptr<Node> node) { return node->id(); });
54     } else {
55       neighbors.resize(itr->second.first.size() + 1);
56       neighbors[0] = id_;
57       std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1,
58                      [](const std::shared_ptr<Node> node) { return node->id(); });
59     }
60   } else {
61     MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
62     if (!exclude_itself) {
63       neighbors.emplace_back(id_);
64     }
65   }
66   *out_neighbors = std::move(neighbors);
67   return Status::OK();
68 }
69 
GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> & neighbors,int32_t samples_num,std::vector<NodeIdType> * out)70 Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
71                                             std::vector<NodeIdType> *out) {
72   std::vector<NodeIdType> shuffled_id(neighbors.size());
73   std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
74   std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
75   int32_t num = std::min(samples_num, static_cast<int32_t>(neighbors.size()));
76   for (int32_t i = 0; i < num; ++i) {
77     out->emplace_back(neighbors[shuffled_id[i]]->id());
78   }
79   return Status::OK();
80 }
81 
GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> & neighbors,const std::vector<WeightType> & weights,int32_t samples_num,std::vector<NodeIdType> * out)82 Status LocalNode::GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
83                                             const std::vector<WeightType> &weights, int32_t samples_num,
84                                             std::vector<NodeIdType> *out) {
85   CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(),
86                                "The number of neighbors does not match the weight.");
87   std::discrete_distribution<NodeIdType> discrete_dist(weights.begin(), weights.end());
88   for (int32_t i = 0; i < samples_num; ++i) {
89     NodeIdType index = discrete_dist(rnd_);
90     out->emplace_back(neighbors[index]->id());
91   }
92   return Status::OK();
93 }
94 
GetSampledNeighbors(NodeType neighbor_type,int32_t samples_num,SamplingStrategy strategy,std::vector<NodeIdType> * out_neighbors)95 Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
96                                       std::vector<NodeIdType> *out_neighbors) {
97   std::vector<NodeIdType> neighbors;
98   neighbors.reserve(samples_num);
99   auto itr = neighbor_nodes_.find(neighbor_type);
100   if (itr != neighbor_nodes_.end()) {
101     if (strategy == SamplingStrategy::kRandom) {
102       while (neighbors.size() < samples_num) {
103         RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors));
104       }
105     } else if (strategy == SamplingStrategy::kEdgeWeight) {
106       RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors));
107     } else {
108       RETURN_STATUS_UNEXPECTED("Invalid strategy");
109     }
110   } else {
111     MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
112     // If there are no neighbors, they are filled with kDefaultNodeId
113     for (int32_t i = 0; i < samples_num; ++i) {
114       neighbors.emplace_back(kDefaultNodeId);
115     }
116   }
117   *out_neighbors = std::move(neighbors);
118   return Status::OK();
119 }
120 
AddNeighbor(const std::shared_ptr<Node> & node,const WeightType & weight)121 Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) {
122   auto itr = neighbor_nodes_.find(node->type());
123   if (itr != neighbor_nodes_.end()) {
124     itr->second.first.push_back(node);
125     itr->second.second.push_back(weight);
126   } else {
127     std::vector<std::shared_ptr<Node>> nodes = {node};
128     std::vector<WeightType> weights = {weight};
129     neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights));
130   }
131   return Status::OK();
132 }
133 
AddAdjacent(const std::shared_ptr<Node> & node,const std::shared_ptr<Edge> & edge)134 Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) {
135   auto node_id = node->id();
136   auto edge_id = edge->id();
137   adjacent_nodes_.insert({node_id, edge_id});
138   return Status::OK();
139 }
140 
GetEdgeByAdjNodeId(const NodeIdType & adj_node_id,EdgeIdType * out_edge_id)141 Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) {
142   auto itr = adjacent_nodes_.find(adj_node_id);
143   if (itr != adjacent_nodes_.end()) {
144     (*out_edge_id) = itr->second;
145   } else {
146     (*out_edge_id) = -1;
147     MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node.";
148   }
149 
150   return Status::OK();
151 }
152 
UpdateFeature(const std::shared_ptr<Feature> & feature)153 Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) {
154   auto itr = features_.find(feature->type());
155   if (itr != features_.end()) {
156     RETURN_STATUS_UNEXPECTED("Feature already exists");
157   } else {
158     features_[feature->type()] = feature;
159     return Status::OK();
160   }
161 }
162 
163 }  // namespace gnn
164 }  // namespace dataset
165 }  // namespace mindspore
166