• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from mindspore import nn
2from mindspore.common.tensor import Tensor
3from mindspore.ops import operations as P
4from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
5import mindspore.hypercomplex.dual as ops
6
7
class HCModel(nn.Cell):
    """Small CNN classifier built from ``mindspore.hypercomplex.dual`` layers.

    The input is split along axis 1 into two components (``u[:, :1]`` and
    ``u[:, 1:]``). These are packed with ``to_2channel`` and run through the
    dual conv / batch-norm / ReLU / max-pool / dense layers. The two resulting
    components are unpacked with ``get_x_and_y``, concatenated along axis 1,
    and mapped to class logits by a regular real-valued ``nn.Dense``.

    Args:
        in_features: Flattened per-component feature size fed to ``fc1``.
            It must equal ``10 * H' * W'``, where 10 is ``conv1``'s output
            channel count and ``H' x W'`` is the spatial size after
            ``max_pool``. The default ``7290 == 10 * 27 * 27`` presumably
            corresponds to one specific input resolution (e.g. 28x28 with a
            size-preserving conv and a stride-1 pool). TODO: confirm against
            the data pipeline.
        hidden_features: Output width of ``fc1`` for each component.
        num_classes: Number of logits produced by ``fc2``.
    """

    def __init__(self, in_features: int = 7290, hidden_features: int = 256,
                 num_classes: int = 10):
        super().__init__()
        self.conv1 = ops.Conv2d(1, 10, kernel_size=3)
        self.bn1 = ops.BatchNorm2d(10)
        # NOTE(review): MindSpore pooling layers default to stride=1 (unlike
        # PyTorch). The default in_features value appears to rely on that,
        # so do not change this without recomputing in_features.
        self.max_pool = ops.MaxPool2d(2)
        self.relu = ops.ReLU()
        self.fc1 = ops.Dense(in_features, hidden_features)
        # construct() concatenates the two fc1 output components along
        # axis 1, so fc2 sees twice the per-component width.
        self.fc2 = nn.Dense(2 * hidden_features, num_classes)
        self.concat = P.Concat(1)

    def construct(self, u: Tensor) -> Tensor:
        """Run the forward pass.

        Args:
            u: Input tensor. ``u[:, :1]`` becomes the first component and
                ``u[:, 1:]`` the second. Presumably shaped (batch, 2, H, W)
                with one channel per component. TODO: confirm.

        Returns:
            Output of ``fc2``: presumably logits of shape
            (batch, num_classes). TODO: confirm.
        """
        # Pack the two input components into one hypercomplex tensor.
        u = to_2channel(u[:, :1], u[:, 1:])
        u = self.conv1(u)
        u = self.bn1(u)
        u = self.relu(u)
        u = self.max_pool(u)
        # Keep the leading axis (size 2, the component axis) and axis 1, and
        # flatten everything else. Axis 1 is presumably the batch axis.
        u = u.view(2, u.shape[1], -1)
        u = self.fc1(u)
        u = self.relu(u)
        # Unpack the two components and lay them side by side so the
        # real-valued fc2 can mix them.
        out_x, out_y = get_x_and_y(u)
        out = self.concat([out_x, out_y])
        return self.fc2(out)
33