• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test layer switch"""
16import numpy as np
17
18import mindspore
19from mindspore import nn
20from mindspore import Tensor
21from mindspore import context
22from mindspore.ops import operations as P
23
class Layer1(nn.Cell):
    """3x3 'same' convolution (3 -> 1 channels) followed by constant padding.

    The padding spec adds 2 entries on the trailing side of axis 1, so the
    single-channel convolution output grows to 3 along that axis. Presumably
    this is done so the output shape matches Layer3's and the layers are
    interchangeable inside the switch nets below.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 1, 3, pad_mode='same')
        # Only axis 1 is padded: (before=0, after=2); all other axes untouched.
        self.pad = nn.Pad(
            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")

    def construct(self, x):
        conv_out = self.net(x)
        return self.pad(conv_out)
34
35
class Layer2(nn.Cell):
    """7x7 'same' convolution (3 -> 1 channels) followed by constant padding.

    Identical to Layer1 except for the larger kernel. Axis 1 is padded by 2 on
    its trailing side, widening the single-channel convolution output to 3.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 1, 7, pad_mode='same')
        # Only axis 1 is padded: (before=0, after=2); all other axes untouched.
        self.pad = nn.Pad(
            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")

    def construct(self, x):
        return self.pad(self.net(x))
46
47
class Layer3(nn.Cell):
    """3x3 'same' convolution mapping 3 input channels to 3 output channels.

    No padding step is needed: the convolution already produces 3 channels.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, 3, pad_mode='same')

    def construct(self, x):
        conv_out = self.net(x)
        return conv_out
55
56
class SwitchNet(nn.Cell):
    """Apply one of Layer1/Layer2/Layer3 to ``x``, selected by ``index``.

    The branch is chosen by indexing the ``self.layers`` tuple directly with
    ``index``. When ``index`` is a Tensor in graph mode, this relies on
    MindSpore's layer-switch support (assumption; confirm against the
    MindSpore version in use).
    """

    def __init__(self):
        super().__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()
        self.layer3 = Layer3()
        self.layers = (self.layer1, self.layer2, self.layer3)
        # NOTE(review): not referenced by construct; kept so the cell's
        # attributes are unchanged.
        self.fill = P.Fill()

    def construct(self, x, index):
        # Select the branch by position and run it on the input.
        return self.layers[index](x)
69
70
class MySwitchNet(nn.Cell):
    """Layer switch written as an unrolled comparison loop.

    Instead of indexing ``self.layers`` with ``index`` directly, every
    position is compared against ``index``, and the matching layer's output
    overwrites a default taken from layer 0. If no position matches, layer 0's
    output is returned.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()
        self.layer3 = Layer3()
        self.layers = (self.layer1, self.layer2, self.layer3)
        # NOTE(review): not referenced by construct; kept so the cell's
        # attributes are unchanged.
        self.fill = P.Fill()

    def construct(self, x, index):
        # Default result: layer 0 applied to the input.
        out = self.layers[0](x)
        # Compare each position with the (possibly Tensor) index; the matching
        # layer replaces the default. This loop shape is the pattern exercised
        # in graph mode by test_layer_switch, so keep it as-is.
        for pos in range(len(self.layers)):
            if pos == index:
                out = self.layers[pos](x)
        return out
86
87
def test_layer_switch():
    """Compile and run MySwitchNet in graph mode and check the output shape.

    For a (3, 3, 24, 24) input, every branch yields (3, 3, 24, 24):
    - 'same' convolutions keep the 24x24 spatial size.
    - Layer3 emits 3 channels directly.
    - Layer1 and Layer2 pad their 1-channel output by 2 along axis 1.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = MySwitchNet()
    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
    index = Tensor(0, dtype=mindspore.int32)
    out = net(x, index)
    # Previously the result was discarded; check it has the expected shape.
    assert out.shape == (3, 3, 24, 24)
94
class MySwitchNetPynative(nn.Cell):
    """Layer switch for PyNative mode: index the layer tuple with ``index``.

    The layers are the same three as in the other switch nets. Here the
    selected layer is looked up directly rather than through a comparison
    loop.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()
        self.layer3 = Layer3()
        self.layers = (self.layer1, self.layer2, self.layer3)
        # NOTE(review): not referenced by construct; kept so the cell's
        # attributes are unchanged.
        self.fill = P.Fill()

    def construct(self, x, index):
        # Run only the layer at position ``index``.
        return self.layers[index](x)
106
107
def test_layer_switch_pynative():
    """Run MySwitchNetPynative in PyNative mode, then switch back to graph mode.

    ``context.set_context`` changes a process-wide setting. The restore
    therefore lives in ``finally``, so that a failure here cannot leave later
    tests running in PyNative mode. Layer3 (index 2) maps the (3, 3, 24, 24)
    input to a (3, 3, 24, 24) output.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        net = MySwitchNetPynative()
        x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
        index = Tensor(2, dtype=mindspore.int32)
        out = net(x, index)
        assert out.shape == (3, 3, 24, 24)
    finally:
        context.set_context(mode=context.GRAPH_MODE)
115