# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.seed import _get_graph_seed
from mindspore.common.api import _cell_graph_executor
from mindspore._checkparam import Validator
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


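# Test fixtures: NetWithLoss caps the network with a VirtualLoss so there is
# a loss node to differentiate, and GradWrap returns the gradients of that
# wrapped network with respect to all three inputs (get_all=True).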
class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


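# @constexpr folds this check into a compile-time constant: a tensor's dtype
# is known when the graph is built, so no branch remains in the executed graph.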
@constexpr
def _is_float_dtype(dtype):
    return dtype in (mstype.float32, mstype.float16)


class Dropout(nn.Cell):
    def __init__(self, keep_prob=0.5, dtype=mstype.float32):
        super(Dropout, self).__init__()
        if keep_prob <= 0 or keep_prob > 1:
            raise ValueError("keep_prob should be a number in range (0, 1], but got {}".format(keep_prob))
        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
        self.keep_prob = keep_prob
        seed0, seed1 = _get_graph_seed(0, "dropout")
        self.seed0 = seed0
        self.seed1 = seed1
        self.dtype = dtype
        self.get_shape = P.Shape()
        self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
        self.dropout_do_mask = P.DropoutDoMask()
        self.cast = P.Cast()
        self.is_gpu = context.get_context('device_target') in ["GPU"]
        self.dropout = P.Dropout(keep_prob)
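    # Identity in eval mode; GPU uses the fused P.Dropout kernel; other
    # backends generate a keep-mask with DropoutGenMask and apply it with
    # DropoutDoMask, casting keep_prob to a float dtype the kernels accept.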
    def construct(self, x):
        if not self.training:
            return x

        if self.is_gpu:
            out, _ = self.dropout(x)
            return out

        if self.keep_prob == 1:
            return x

        shape = self.get_shape(x)
        dtype = P.DType()(x)
        if _is_float_dtype(dtype):
            keep_prob = self.cast(self.keep_prob, dtype)
        else:
            keep_prob = self.cast(self.keep_prob, mstype.float16)
        output = self.dropout_gen_mask(shape, keep_prob)
        return self.dropout_do_mask(x, output, keep_prob)
    def extend_repr(self):
        return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
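

# A minimal usage sketch of the Dropout cell above (illustrative only: the
# keep_prob value and shapes are assumptions, not part of this test):
#
#   dropout = Dropout(keep_prob=0.8)
#   dropout.set_train()
#   x = Tensor(np.ones([16, 32]), dtype=ms.float32)
#   out = dropout(x)  # ~20% of elements zeroed, survivors scaled by 1/keep_prob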


# model_parallel test
def test_two_matmul_dropout():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.dropout = Dropout()
            self.dropout.dropout_do_mask.shard(strategy2)
            self.dropout.dropout_gen_mask.shard(strategy2)
            self.matmul2 = P.MatMul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.dropout(out)
            out = self.matmul2(out, b)
            return out

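    # Eight devices: strategy1 lays matmul1's first input out on a 4x2 grid
    # (its second input 2x1), strategy2 shards the mask ops 8-ways along the
    # row axis, and strategy3 splits matmul2 8-ways along the contracting
    # dimension.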
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((4, 2), (2, 1))
    strategy2 = ((8, 1),)
    strategy3 = ((1, 8), (8, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net.set_train()
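    # Compiling the graph is enough for this UT: shard strategies are
    # validated at graph compilation time, so no device execution is needed.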
    _cell_graph_executor.compile(net, x, y, b)