# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BGCF architecture and a smoke test that exports it to MINDIR and AIR."""
import os
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.train.serialization import export

context.set_context(mode=context.PYNATIVE_MODE)


class MeanConv(nn.Cell):
    """Average the neighbour features and project them together with the self feature.

    Args:
        feature_in_dim: Input feature width. The projection weight has
            ``feature_in_dim * 2`` rows, so the self feature and the averaged
            neighbour feature are presumably ``feature_in_dim`` wide each.
        feature_out_dim: Number of columns of the projection weight.
        activation: Either ``"tanh"`` or ``"relu"``.
        dropout: Drop rate; the dropout layer is built with
            ``keep_prob = 1 - dropout``.

    Raises:
        ValueError: If ``activation`` is neither ``"tanh"`` nor ``"relu"``.
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 activation,
                 dropout=0.2):
        super(MeanConv, self).__init__()
        self.out_weight = Parameter(
            initializer("XavierUniform", [feature_in_dim * 2, feature_out_dim], dtype=mstype.float32))
        if activation == "tanh":
            self.act = P.Tanh()
        elif activation == "relu":
            self.act = P.ReLU()
        else:
            raise ValueError("activation should be tanh or relu")
        self.cast = P.Cast()
        self.matmul = P.MatMul()
        self.concat = P.Concat(axis=1)
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.dropout = nn.Dropout(keep_prob=1 - dropout)

    def construct(self, self_feature, neigh_feature):
        """Return ``act(concat(self_feature, mean(neigh_feature, axis=1)) @ out_weight)``.

        Dropout is applied to the averaged neighbour feature before the concat.
        """
        # Axis 1 is reduced away (keep_dims=False).
        neigh_matrix = self.reduce_mean(neigh_feature, 1)
        neigh_matrix = self.dropout(neigh_matrix)
        output = self.concat((self_feature, neigh_matrix))
        output = self.act(self.matmul(output, self.out_weight))
        return output


class AttenConv(nn.Cell):
    """Aggregate neighbour features with softmax attention scored against the self feature.

    Args:
        feature_in_dim: Input feature width. The projection weight has
            ``feature_in_dim * 2`` rows, so the aggregated neighbour feature
            and the self feature are presumably ``feature_in_dim`` wide each.
        feature_out_dim: Number of columns of the projection weight.
        dropout: Drop rate; the dropout layer is built with
            ``keep_prob = 1 - dropout``.
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 dropout=0.2):
        super(AttenConv, self).__init__()
        self.out_weight = Parameter(
            initializer("XavierUniform", [feature_in_dim * 2, feature_out_dim], dtype=mstype.float32))
        self.cast = P.Cast()
        self.squeeze = P.Squeeze(1)
        self.concat = P.Concat(axis=1)
        self.expanddims = P.ExpandDims()
        self.softmax = P.Softmax(axis=-1)
        self.matmul = P.MatMul()
        self.matmul_3 = P.BatchMatMul()
        self.matmul_t = P.BatchMatMul(transpose_b=True)
        self.dropout = nn.Dropout(keep_prob=1 - dropout)

    def construct(self, self_feature, neigh_feature):
        """Return ``concat(attention_aggregate, self_feature) @ out_weight``.

        No activation is applied here; in this file the caller (``BGCF``)
        applies tanh to the concatenated aggregator outputs.
        """
        # Insert axis 1 so the self feature can be batch-multiplied with the neighbours.
        query = self.expanddims(self_feature, 1)
        neigh_matrix = self.dropout(neigh_feature)
        # query @ neigh_matrix^T, softmax-normalised over the last axis.
        score = self.matmul_t(query, neigh_matrix)
        score = self.softmax(score)
        atten_agg = self.matmul_3(score, neigh_matrix)
        # Drop the axis that was inserted for the query.
        atten_agg = self.squeeze(atten_agg)
        output = self.matmul(self.concat((atten_agg, self_feature)), self.out_weight)
        return output


class BGCF(nn.Cell):
    """BGCF network producing raw embeddings and aggregated representations.

    Args:
        dataset_argv: Sequence unpacked as ``(input_dim, num_user, num_item)``
            and stored on the instance; ``input_dim`` sizes the aggregators.
        architect_argv: Stored as ``layer_dim``; output width of every aggregator.
        activation: Activation name forwarded to the ``MeanConv`` aggregators.
        neigh_drop_rate: Indexable with three drop rates: ``[0]`` for the raw
            mean aggregators, ``[1]`` for the gnew mean aggregator and ``[2]``
            for the attention aggregators.
        num_user: Number of rows of the user embedding table.
        num_item: Number of rows of the item embedding table.
        input_dim: Number of columns of both embedding tables.

    NOTE(review): ``num_user``/``num_item``/``input_dim`` and ``dataset_argv``
    carry the same quantities in the test below; presumably they must agree.
    """

    def __init__(self,
                 dataset_argv,
                 architect_argv,
                 activation,
                 neigh_drop_rate,
                 num_user,
                 num_item,
                 input_dim):
        super(BGCF, self).__init__()
        self.user_embed = Parameter(initializer("XavierUniform", [num_user, input_dim], dtype=mstype.float32))
        self.item_embed = Parameter(initializer("XavierUniform", [num_item, input_dim], dtype=mstype.float32))
        self.cast = P.Cast()
        self.tanh = P.Tanh()
        self.shape = P.Shape()
        self.split = P.Split(0, 2)
        self.gather = P.Gather()
        self.reshape = P.Reshape()
        self.concat_0 = P.Concat(0)
        self.concat_1 = P.Concat(1)
        (self.input_dim, self.num_user, self.num_item) = dataset_argv
        self.layer_dim = architect_argv
        # Every aggregator is switched to float16 computation via to_float.
        self.gnew_agg_mean = MeanConv(self.input_dim, self.layer_dim,
                                      activation=activation, dropout=neigh_drop_rate[1])
        self.gnew_agg_mean.to_float(mstype.float16)
        self.gnew_agg_user = AttenConv(self.input_dim, self.layer_dim, dropout=neigh_drop_rate[2])
        self.gnew_agg_user.to_float(mstype.float16)
        self.gnew_agg_item = AttenConv(self.input_dim, self.layer_dim, dropout=neigh_drop_rate[2])
        self.gnew_agg_item.to_float(mstype.float16)
        self.user_feature_dim = self.input_dim
        self.item_feature_dim = self.input_dim
        # Not referenced by construct() in this file.
        self.final_weight = Parameter(
            initializer("XavierUniform", [self.input_dim * 3, self.input_dim * 3], dtype=mstype.float32))
        self.raw_agg_funcs_user = MeanConv(self.input_dim, self.layer_dim,
                                           activation=activation, dropout=neigh_drop_rate[0])
        self.raw_agg_funcs_user.to_float(mstype.float16)
        self.raw_agg_funcs_item = MeanConv(self.input_dim, self.layer_dim,
                                           activation=activation, dropout=neigh_drop_rate[0])
        self.raw_agg_funcs_item.to_float(mstype.float16)

    def construct(self,
                  u_id,
                  pos_item_id,
                  neg_item_id,
                  pos_users,
                  pos_items,
                  u_group_nodes,
                  u_neighs,
                  u_gnew_neighs,
                  i_group_nodes,
                  i_neighs,
                  i_gnew_neighs,
                  neg_group_nodes,
                  neg_neighs,
                  neg_gnew_neighs,
                  neg_item_num):
        """Compute embeddings and representations for users, positive and negative items.

        All id/neighbour arguments are index tensors into the user or item
        embedding table. ``neg_item_num`` is the size of the middle axis that
        the negative-item representation is reshaped to.

        Returns:
            Tuple ``(all_user_embed, all_user_rep, all_pos_item_embed,
            all_pos_item_rep, neg_item_embed, neg_item_rep)``.
        """
        # Users: raw mean, gnew mean and attention aggregations, concatenated on axis 1.
        all_user_embed = self.gather(self.user_embed, self.concat_0((u_id, pos_users)), 0)
        u_self_matrix_at_layers = self.gather(self.user_embed, u_group_nodes, 0)
        u_neigh_matrix_at_layers = self.gather(self.item_embed, u_neighs, 0)
        u_output_mean = self.raw_agg_funcs_user(u_self_matrix_at_layers, u_neigh_matrix_at_layers)
        u_gnew_neighs_matrix = self.gather(self.item_embed, u_gnew_neighs, 0)
        u_output_from_gnew_mean = self.gnew_agg_mean(u_self_matrix_at_layers, u_gnew_neighs_matrix)
        u_output_from_gnew_att = self.gnew_agg_user(
            u_self_matrix_at_layers,
            self.concat_1((u_neigh_matrix_at_layers, u_gnew_neighs_matrix)))
        u_output = self.concat_1((u_output_mean, u_output_from_gnew_mean, u_output_from_gnew_att))
        all_user_rep = self.tanh(u_output)

        # Positive items: same pipeline with the roles of the two tables swapped.
        all_pos_item_embed = self.gather(self.item_embed, self.concat_0((pos_item_id, pos_items)), 0)
        i_self_matrix_at_layers = self.gather(self.item_embed, i_group_nodes, 0)
        i_neigh_matrix_at_layers = self.gather(self.user_embed, i_neighs, 0)
        i_output_mean = self.raw_agg_funcs_item(i_self_matrix_at_layers, i_neigh_matrix_at_layers)
        i_gnew_neighs_matrix = self.gather(self.user_embed, i_gnew_neighs, 0)
        i_output_from_gnew_mean = self.gnew_agg_mean(i_self_matrix_at_layers, i_gnew_neighs_matrix)
        i_output_from_gnew_att = self.gnew_agg_item(
            i_self_matrix_at_layers,
            self.concat_1((i_neigh_matrix_at_layers, i_gnew_neighs_matrix)))
        i_output = self.concat_1((i_output_mean, i_output_from_gnew_mean, i_output_from_gnew_att))
        all_pos_item_rep = self.tanh(i_output)

        # Negative items: reuse the item aggregators, then regroup per negative sample.
        neg_item_embed = self.gather(self.item_embed, neg_item_id, 0)
        neg_self_matrix_at_layers = self.gather(self.item_embed, neg_group_nodes, 0)
        neg_neigh_matrix_at_layers = self.gather(self.user_embed, neg_neighs, 0)
        neg_output_mean = self.raw_agg_funcs_item(neg_self_matrix_at_layers, neg_neigh_matrix_at_layers)
        neg_gnew_neighs_matrix = self.gather(self.user_embed, neg_gnew_neighs, 0)
        neg_output_from_gnew_mean = self.gnew_agg_mean(neg_self_matrix_at_layers, neg_gnew_neighs_matrix)
        neg_output_from_gnew_att = self.gnew_agg_item(
            neg_self_matrix_at_layers,
            self.concat_1((neg_neigh_matrix_at_layers, neg_gnew_neighs_matrix)))
        neg_output = self.concat_1((neg_output_mean, neg_output_from_gnew_mean, neg_output_from_gnew_att))
        neg_output = self.tanh(neg_output)
        neg_output_shape = self.shape(neg_output)
        neg_item_rep = self.reshape(neg_output,
                                    (self.shape(neg_item_embed)[0], neg_item_num, neg_output_shape[-1]))

        return all_user_embed, all_user_rep, all_pos_item_embed, all_pos_item_rep, neg_item_embed, neg_item_rep


class ForwardBGCF(nn.Cell):
    """Inference wrapper that exposes only the user and item representations of a BGCF network.

    Args:
        network: The wrapped cell; called with the 15 arguments of ``BGCF.construct``.
    """

    def __init__(self,
                 network):
        super(ForwardBGCF, self).__init__()
        self.network = network

    def construct(self, users, items, neg_items, u_neighs, u_gnew_neighs, i_neighs, i_gnew_neighs):
        """Run the wrapped network and return ``(user_rep, item_rep)``.

        ``users``/``items`` are reused for the id, positive and group-node
        slots, the item tensors are reused for the negative-item slots, and
        ``neg_item_num`` is fixed to 1.
        """
        _, user_rep, _, item_rep, _, _ = self.network(users, items, neg_items, users, items, users,
                                                      u_neighs, u_gnew_neighs, items, i_neighs, i_gnew_neighs,
                                                      items, i_neighs, i_gnew_neighs, 1)
        return user_rep, item_rep


def _export_and_check(net, inputs, file_name, file_format, suffix):
    """Export ``net`` and assert that ``file_name + suffix`` was written, always cleaning it up."""
    exported_file = file_name + suffix
    # A leftover file from an earlier run would make the existence check meaningless.
    if os.path.exists(exported_file):
        os.remove(exported_file)
    try:
        export(net, *inputs, file_name=file_name, file_format=file_format)
        assert os.path.exists(exported_file)
    finally:
        # Do not leave artefacts behind even when export raises part-way through.
        if os.path.exists(exported_file):
            os.remove(exported_file)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_export_bgcf():
    """Export the BGCF forward network to MINDIR and then AIR and check that each file is produced."""
    num_user, num_item = 7068, 3570
    network = BGCF([64, num_user, num_item], 64, "tanh",
                   [0.0, 0.0, 0.0], num_user, num_item, 64)

    forward_net = ForwardBGCF(network)
    users = Tensor(np.zeros([num_user,]).astype(np.int32))
    items = Tensor(np.zeros([num_item,]).astype(np.int32))
    neg_items = Tensor(np.zeros([num_item, 1]).astype(np.int32))
    u_test_neighs = Tensor(np.zeros([num_user, 40]).astype(np.int32))
    u_test_gnew_neighs = Tensor(np.zeros([num_user, 20]).astype(np.int32))
    i_test_neighs = Tensor(np.zeros([num_item, 40]).astype(np.int32))
    i_test_gnew_neighs = Tensor(np.zeros([num_item, 20]).astype(np.int32))
    input_data = [users, items, neg_items, u_test_neighs, u_test_gnew_neighs, i_test_neighs, i_test_gnew_neighs]
    file_name = "bgcf"
    _export_and_check(forward_net, input_data, file_name, "MINDIR", ".mindir")
    _export_and_check(forward_net, input_data, file_name, "AIR", ".air")