1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "minddata/dataset/engine/gnn/local_node.h"

#include <algorithm>
#include <numeric>
#include <random>
#include <string>
#include <utility>

#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/random.h"
25
26 namespace mindspore {
27 namespace dataset {
28 namespace gnn {
29
LocalNode(NodeIdType id,NodeType type,WeightType weight)30 LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight)
31 : Node(id, type, weight), rnd_(GetRandomDevice()) {
32 rnd_.seed(GetSeed());
33 }
34
GetFeatures(FeatureType feature_type,std::shared_ptr<Feature> * out_feature)35 Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
36 auto itr = features_.find(feature_type);
37 if (itr != features_.end()) {
38 *out_feature = itr->second;
39 return Status::OK();
40 } else {
41 std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
42 RETURN_STATUS_UNEXPECTED(err_msg);
43 }
44 }
45
GetAllNeighbors(NodeType neighbor_type,std::vector<NodeIdType> * out_neighbors,bool exclude_itself)46 Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, bool exclude_itself) {
47 std::vector<NodeIdType> neighbors;
48 auto itr = neighbor_nodes_.find(neighbor_type);
49 if (itr != neighbor_nodes_.end()) {
50 if (exclude_itself) {
51 neighbors.resize(itr->second.first.size());
52 std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(),
53 [](const std::shared_ptr<Node> node) { return node->id(); });
54 } else {
55 neighbors.resize(itr->second.first.size() + 1);
56 neighbors[0] = id_;
57 std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1,
58 [](const std::shared_ptr<Node> node) { return node->id(); });
59 }
60 } else {
61 MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
62 if (!exclude_itself) {
63 neighbors.emplace_back(id_);
64 }
65 }
66 *out_neighbors = std::move(neighbors);
67 return Status::OK();
68 }
69
GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> & neighbors,int32_t samples_num,std::vector<NodeIdType> * out)70 Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
71 std::vector<NodeIdType> *out) {
72 std::vector<NodeIdType> shuffled_id(neighbors.size());
73 std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
74 std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
75 int32_t num = std::min(samples_num, static_cast<int32_t>(neighbors.size()));
76 for (int32_t i = 0; i < num; ++i) {
77 out->emplace_back(neighbors[shuffled_id[i]]->id());
78 }
79 return Status::OK();
80 }
81
GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> & neighbors,const std::vector<WeightType> & weights,int32_t samples_num,std::vector<NodeIdType> * out)82 Status LocalNode::GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
83 const std::vector<WeightType> &weights, int32_t samples_num,
84 std::vector<NodeIdType> *out) {
85 CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(),
86 "The number of neighbors does not match the weight.");
87 std::discrete_distribution<NodeIdType> discrete_dist(weights.begin(), weights.end());
88 for (int32_t i = 0; i < samples_num; ++i) {
89 NodeIdType index = discrete_dist(rnd_);
90 out->emplace_back(neighbors[index]->id());
91 }
92 return Status::OK();
93 }
94
GetSampledNeighbors(NodeType neighbor_type,int32_t samples_num,SamplingStrategy strategy,std::vector<NodeIdType> * out_neighbors)95 Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
96 std::vector<NodeIdType> *out_neighbors) {
97 std::vector<NodeIdType> neighbors;
98 neighbors.reserve(samples_num);
99 auto itr = neighbor_nodes_.find(neighbor_type);
100 if (itr != neighbor_nodes_.end()) {
101 if (strategy == SamplingStrategy::kRandom) {
102 while (neighbors.size() < samples_num) {
103 RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors));
104 }
105 } else if (strategy == SamplingStrategy::kEdgeWeight) {
106 RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors));
107 } else {
108 RETURN_STATUS_UNEXPECTED("Invalid strategy");
109 }
110 } else {
111 MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
112 // If there are no neighbors, they are filled with kDefaultNodeId
113 for (int32_t i = 0; i < samples_num; ++i) {
114 neighbors.emplace_back(kDefaultNodeId);
115 }
116 }
117 *out_neighbors = std::move(neighbors);
118 return Status::OK();
119 }
120
AddNeighbor(const std::shared_ptr<Node> & node,const WeightType & weight)121 Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) {
122 auto itr = neighbor_nodes_.find(node->type());
123 if (itr != neighbor_nodes_.end()) {
124 itr->second.first.push_back(node);
125 itr->second.second.push_back(weight);
126 } else {
127 std::vector<std::shared_ptr<Node>> nodes = {node};
128 std::vector<WeightType> weights = {weight};
129 neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights));
130 }
131 return Status::OK();
132 }
133
AddAdjacent(const std::shared_ptr<Node> & node,const std::shared_ptr<Edge> & edge)134 Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) {
135 auto node_id = node->id();
136 auto edge_id = edge->id();
137 adjacent_nodes_.insert({node_id, edge_id});
138 return Status::OK();
139 }
140
GetEdgeByAdjNodeId(const NodeIdType & adj_node_id,EdgeIdType * out_edge_id)141 Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) {
142 auto itr = adjacent_nodes_.find(adj_node_id);
143 if (itr != adjacent_nodes_.end()) {
144 (*out_edge_id) = itr->second;
145 } else {
146 (*out_edge_id) = -1;
147 MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node.";
148 }
149
150 return Status::OK();
151 }
152
UpdateFeature(const std::shared_ptr<Feature> & feature)153 Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) {
154 auto itr = features_.find(feature->type());
155 if (itr != features_.end()) {
156 RETURN_STATUS_UNEXPECTED("Feature already exists");
157 } else {
158 features_[feature->type()] = feature;
159 return Status::OK();
160 }
161 }
162
163 } // namespace gnn
164 } // namespace dataset
165 } // namespace mindspore
166