# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test layer switch"""
import numpy as np

import mindspore
from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P

# Padding spec shared by Layer1 and Layer2: leave axis 0, 2 and 3 alone and
# append two entries after the existing ones on axis 1. With NCHW input this
# grows a 1-channel conv output to 3 channels, matching Layer3's output
# (presumably so every switchable branch yields the same shape).
_CHANNEL_PADDINGS = ((0, 0), (0, 2), (0, 0), (0, 0))


class Layer1(nn.Cell):
    """3x3 'same' Conv2d (3 -> 1 channels) followed by constant channel padding."""

    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 1, 3, pad_mode='same')
        self.pad = nn.Pad(paddings=_CHANNEL_PADDINGS, mode="CONSTANT")

    def construct(self, x):
        return self.pad(self.net(x))


class Layer2(nn.Cell):
    """7x7 'same' Conv2d (3 -> 1 channels) followed by constant channel padding."""

    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 1, 7, pad_mode='same')
        self.pad = nn.Pad(paddings=_CHANNEL_PADDINGS, mode="CONSTANT")

    def construct(self, x):
        return self.pad(self.net(x))


class Layer3(nn.Cell):
    """3x3 'same' Conv2d mapping 3 channels to 3 channels; no padding needed."""

    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, 3, pad_mode='same')

    def construct(self, x):
        return self.net(x)


class SwitchNet(nn.Cell):
    """Apply the layer chosen by directly indexing a tuple of cells with ``index``."""

    def __init__(self):
        super().__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()
        self.layer3 = Layer3()
        self.layers = (self.layer1, self.layer2, self.layer3)
        self.fill = P.Fill()  # not referenced by construct

    def construct(self, x, index):
        return self.layers[index](x)


class MySwitchNet(nn.Cell):
    """Select one of three layers by comparing ``index`` with each tuple position.

    ``construct`` first applies ``layers[0]`` so the result is always bound.
    It then overwrites the result with the output of whichever position
    equals ``index``.
    """

    def __init__(self):
        super(MySwitchNet, self).__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()
        self.layer3 = Layer3()
        self.layers = (self.layer1, self.layer2, self.layer3)
        self.fill = P.Fill()  # not referenced by construct

    def construct(self, x, index):
        # Default branch so ``y`` is defined on every control-flow path.
        y = self.layers[0](x)
        # The loop bound is the Python-level length of a fixed tuple, so the
        # graph compiler can presumably unroll it. Kept as range(len(...))
        # rather than enumerate() because graph-mode support for the latter
        # over a tuple of cells is not confirmed here.
        for i in range(len(self.layers)):
            if i == index:
                y = self.layers[i](x)
        return y


def test_layer_switch():
    """Run MySwitchNet in graph mode with index 0 and check the output shape."""
    context.set_context(mode=context.GRAPH_MODE)
    net = MySwitchNet()
    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
    index = Tensor(0, dtype=mindspore.int32)
    out = net(x, index)
    # Every branch maps (N, 3, H, W) to (N, 3, H, W). The convs use 'same'
    # padding, Layer1/Layer2 pad their single conv channel up to three, and
    # Layer3 convolves 3 -> 3.
    assert out.asnumpy().shape == (3, 3, 24, 24)


class MySwitchNetPynative(nn.Cell):
    """Select a layer by indexing the tuple of cells with ``index`` (PyNative test)."""

    def __init__(self):
        super(MySwitchNetPynative, self).__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()
        self.layer3 = Layer3()
        self.layers = (self.layer1, self.layer2, self.layer3)
        self.fill = P.Fill()  # not referenced by construct

    def construct(self, x, index):
        # NOTE(review): relies on tuple indexing accepting an int32 scalar
        # Tensor in PyNative mode.
        return self.layers[index](x)


def test_layer_switch_pynative():
    """Run MySwitchNetPynative in PyNative mode with index 2, then restore graph mode."""
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        net = MySwitchNetPynative()
        x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
        index = Tensor(2, dtype=mindspore.int32)
        out = net(x, index)
        # Layer3: 3x3 'same' conv, 3 -> 3 channels, spatial size preserved.
        assert out.asnumpy().shape == (3, 3, 24, 24)
    finally:
        # The execution mode is process-global. Restore it even when the
        # body fails so later tests do not silently run in PyNative mode.
        context.set_context(mode=context.GRAPH_MODE)