• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from mindspore import nn
2from mindspore.common.tensor import Tensor
3from mindspore.ops import operations as P
4import mindspore.hypercomplex.dual as ops
5
6
class DeepConvNet(nn.Cell):
    """1-D convolutional network built mostly from dual-number (hypercomplex) layers.

    Feature extractor: six dual ``Conv1d`` layers, each followed by ReLU.
    BatchNorm1d follows every conv except ``conv5``. 2x average pooling follows
    every conv except ``conv5``. A zero pad of 2 is applied on the last axis
    after the first pooling stage.

    Head: a dual ``Dense(4096 -> 1024)`` plus ReLU. The two leading axes are then
    swapped and everything is flattened per sample. A real-valued
    ``nn.Dense(2048 -> num_classes)`` and an element-wise sigmoid produce the output.

    Args:
        in_channels: Number of input channels seen by ``conv1``. Defaults to 1,
            which matches the original hard-coded architecture.
        num_classes: Width of the final sigmoid output. Defaults to 84, which
            matches the original hard-coded architecture.

    NOTE(review): the input appears to be 4-D. ``pad1`` pads a rank-4 tensor, so
    the layout is presumably (2, batch, channels, length), with the leading axis
    holding the two dual components. This is consistent with fc2's
    2048 = 2 * 1024 input after permute/flatten. Confirm against the
    ``mindspore.hypercomplex.dual`` documentation.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 84):
        super(DeepConvNet, self).__init__()

        # Attribute names are part of the parameter names stored in checkpoints;
        # keep them stable.
        self.conv1 = ops.Conv1d(in_channels, 16, kernel_size=6, stride=2, padding=2, pad_mode='pad')
        self.bn1 = ops.BatchNorm1d(16)
        self.avg_pool1 = ops.AvgPool1d(kernel_size=2, stride=2)
        # Zero-pads only the last axis: 0 elements before, 2 after.
        self.pad1 = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 2)), mode='CONSTANT')

        # No pad_mode given: the layer's default padding mode applies (padding=0).
        self.conv2 = ops.Conv1d(16, 32, kernel_size=3, stride=2, padding=0)
        self.bn2 = ops.BatchNorm1d(32)
        self.avg_pool2 = ops.AvgPool1d(kernel_size=2, stride=2)

        self.conv3 = ops.Conv1d(32, 64, kernel_size=3, stride=1, padding=1, pad_mode='pad')
        self.bn3 = ops.BatchNorm1d(64)
        self.avg_pool3 = ops.AvgPool1d(kernel_size=2, stride=2)

        self.conv4 = ops.Conv1d(64, 64, kernel_size=3, stride=1, padding=1, pad_mode='pad')
        self.bn4 = ops.BatchNorm1d(64)
        self.avg_pool4 = ops.AvgPool1d(kernel_size=2, stride=2)

        # conv5 is deliberately left without batch norm or pooling (as in the original).
        self.conv5 = ops.Conv1d(64, 128, kernel_size=3, stride=1, padding=1, pad_mode='pad')
        self.conv6 = ops.Conv1d(128, 128, kernel_size=3, stride=1, padding=1, pad_mode='pad')
        self.bn6 = ops.BatchNorm1d(128)
        self.avg_pool6 = ops.AvgPool1d(kernel_size=2, stride=2)

        self.shape_op = P.Shape()
        self.reshape = P.Reshape()
        self.permute = P.Transpose()
        self.flatten = P.Flatten()

        # fc1 sees 128 channels * remaining length. 4096 = 128 * 32, so the
        # input length must reduce to 32 after avg_pool6.
        # TODO(review): document the expected raw input length.
        self.fc1 = ops.Dense(4096, 1024)
        # fc2 is real-valued. Its 2048 inputs are fc1's 1024 features for each
        # of the two leading-axis components, concatenated per sample.
        self.fc2 = nn.Dense(2048, num_classes)

        # Single ReLU cell reused after every stage.
        self.relu = ops.ReLU()
        self.sigmoid = nn.Sigmoid()

    def construct(self, u: Tensor) -> Tensor:
        """Run the forward pass and return sigmoid scores of width ``num_classes``."""
        u = self.conv1(u)
        u = self.bn1(u)
        u = self.relu(u)
        u = self.avg_pool1(u)
        u = self.pad1(u)

        u = self.conv2(u)
        u = self.bn2(u)
        u = self.relu(u)
        u = self.avg_pool2(u)

        u = self.conv3(u)
        u = self.bn3(u)
        u = self.relu(u)
        u = self.avg_pool3(u)

        u = self.conv4(u)
        u = self.bn4(u)
        u = self.relu(u)
        u = self.avg_pool4(u)

        u = self.conv5(u)
        u = self.relu(u)

        u = self.conv6(u)
        u = self.bn6(u)
        u = self.relu(u)
        u = self.avg_pool6(u)

        # Keep the first two axes and merge everything else into one feature axis.
        u_shape = self.shape_op(u)
        u = self.reshape(u, (u_shape[0], u_shape[1], -1))
        u = self.fc1(u)
        u = self.relu(u)

        # Swap axes 0 and 1 so the batch axis leads, then flatten per sample.
        u = self.permute(u, (1, 0, 2))
        x = self.flatten(u)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x
84