• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
""" test ReID ops: forward and gradient compile checks on the GE graph pipelines """
16import functools
17import numpy as np
18
19import mindspore.nn as nn
20from mindspore.ops import operations as P
21from ....mindspore_test_framework.mindspore_test import mindspore_test
22from ....mindspore_test_framework.pipeline.forward.compile_forward \
23    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
24from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
25    import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
26from ....ops_common import convert
27
28
class SeqConvBnRelu(nn.Cell):
    """Composite test block: Conv2d, then BatchNorm2d, then ReLU.

    Args:
        in_ch: number of input channels of the convolution.
        out_ch: number of output channels of the convolution; also the
            feature count of the batch-norm layer.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names/order are kept as-is: sub-cell parameters are
        # registered under these attribute names.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3)
        self.bn = nn.BatchNorm2d(num_features=out_ch)
        self.relu = P.ReLU()

    def construct(self, input_x):
        """Run convolution, batch normalization and ReLU in sequence."""
        conv_out = self.conv(input_x)
        normed = self.bn(conv_out)
        return self.relu(normed)
40
41
# Test cases for ops used by ReID networks. Each entry is (case_id, config).
# Config keys used in this list:
#   'block'       - the primitive or Cell under test.
#   'desc_const'  - constant arguments for the block (presumably passed after
#                   the tensor inputs by the test framework -- confirm there).
#   'desc_inputs' - forward inputs, given as a shape list or a convert(...) value.
#   'desc_bprop'  - inputs for the backward pass, same forms; here they always
#                   match the forward output shape (presumably the sens value).
#   'skip'        - optional list of pipeline tags to skip; this module checks
#                   'exec', 'backward' and 'backward_exec'.
# Case ids must be unique: the case-by-case pipelines appear to pair blocks
# with inputs by id, so a repeated id cross-combines different cases.
test_case_reid_ops = [
    ('ReduceMax', {
        'block': P.ReduceMax(keep_dims=False),
        'desc_const': [(1,)],
        'desc_inputs': [convert([32, 32], np.float16)],
        'desc_bprop': [convert([32], np.float16)],
        'skip': []}),
    ('ReduceMin', {
        'block': P.ReduceMin(),
        'desc_const': [(1,)],
        'desc_inputs': [[32, 32]],
        'desc_bprop': [[32]],
        'skip': []}),
    ('ReduceMean', {
        'block': P.ReduceMean(keep_dims=True),
        'desc_const': [(1, 2)],
        'desc_inputs': [[32, 4, 4]],
        'desc_bprop': [[32, 1, 1]]}),
    ('Log', {
        'block': P.Log(),
        'desc_inputs': [[4, 128, 1024]],
        'desc_bprop': [[4, 128, 1024]],
        'skip': ['backward']}),  # backward skipped: it reported an error; re-check before enabling
    ('Reciprocal', {
        'block': P.Reciprocal(),
        'desc_inputs': [[4, 128, 1024]],
        'desc_bprop': [[4, 128, 1024]],
        'skip': ['backward']}),
    ('FloorDiv', {
        'block': P.FloorDiv(),
        'desc_inputs': [[4, 128, 1024], [4, 128, 1024]],
        'desc_bprop': [[4, 128, 1024]]}),
    ('Sigmoid', {
        'block': P.Sigmoid(),
        'desc_inputs': [[4, 128, 1024]],
        'desc_bprop': [[4, 128, 1024]]}),
    ('Softmax', {
        'block': P.Softmax(),
        'desc_inputs': [[1, 16]],
        'desc_bprop': [[1, 16]],
        'skip': ['backward']}),  # backward skipped: it reported an error; re-check before enabling
    # Multi-axis variant; needs its own id so it is not merged with 'Softmax'.
    ('SoftmaxMultiAxis', {
        'block': P.Softmax(axis=(0, 1)),
        'desc_inputs': [[1, 16]],
        'desc_bprop': [[1, 16]],
        'skip': ['backward']}),
    ('L2Normalize', {
        'block': P.L2Normalize(),
        'desc_inputs': [[4, 128, 1024]],
        'desc_bprop': [[4, 128, 1024]]}),
    ('ReLU', {
        'block': P.ReLU(),
        'desc_inputs': [[64, 64, 112, 112]],
        'desc_bprop': [[64, 64, 112, 112]]}),
    ('SeqConvBnRelu', {
        'block': SeqConvBnRelu(3, 64),
        'desc_inputs': [[64, 3, 112, 112]],
        'desc_bprop': [[64, 64, 112, 112]]}),
    ('PReluCell', {
        'block': nn.PReLU(1, [np.float32(0.25)]),
        'desc_inputs': [[128, 64, 112, 112]],
        'desc_bprop': [[128, 64, 112, 112]]}),
    ('PRelu', {
        'block': P.PReLU(),
        'desc_inputs': [[128, 64, 112, 112], [64,]],
        'desc_bprop': [[128, 64, 112, 112]]}),
    ('Cos', {
        'block': P.Cos(),
        'desc_inputs': [[8, 16]],
        'desc_bprop': [[8, 16]]}),
    ('ACos', {
        'block': P.ACos(),
        'desc_inputs': [[8, 16]],
        'desc_bprop': [[8, 16]]}),
    ('Exp', {
        'block': P.Exp(),
        'desc_inputs': [[256, 8]],
        'desc_bprop': [[256, 8]]}),
    ('Pow', {
        'block': P.Pow(),
        'desc_const': [2.0],
        'desc_inputs': [[1, 512]],
        'desc_bprop': [[1, 512]]}),
    ('LogicalNot', {
        'block': P.LogicalNot(),
        'desc_inputs': [convert([256], np.bool_)],
        'desc_bprop': [convert([256], np.bool_)]}),
    ('Equal', {
        'block': P.Equal(),
        'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
        'desc_bprop': [convert([256], np.bool_)]}),
    ('Greater', {
        'block': P.Greater(),
        'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
        'desc_bprop': [convert([256], np.bool_)]}),
    ('Dropout', {
        'block': nn.Dropout(),
        'desc_inputs': [[1, 512, 7, 7]],
        'desc_bprop': [[1, 512, 7, 7]]}),
    ('MatMul', {
        'block': P.MatMul(),
        'desc_inputs': [[64, 512], [512, 64]],
        'desc_bprop': [[64, 64]]}),
    ('Maximum', {
        'block': P.Maximum(),
        'desc_inputs': [[64, 1], [64, 1]],
        'desc_bprop': [[64, 1]]}),
]
150
# Groups of (case_id, config) pairs; concatenated into one flat case list.
test_case_lists = [test_case_reid_ops]
test_case = functools.reduce(lambda merged, group: merged + group, test_case_lists)
# Use pytest's -k option to select specific tests, e.g.:
# pytest <path/to/this_file>.py::test_backward_exec -k <expression>
155
156
def _skip_tags(case):
    """Return the 'skip' tag list of a (case_id, config) pair, or [] if absent."""
    return case[1].get('skip', [])


# Materialized as lists: `filter` returns a one-shot iterator, so any second
# pass over it (e.g. by the test framework) would silently see zero cases.
# Forward-compile cases: everything not tagged 'exec' in its skip list.
test_exec_case = [case for case in test_case
                  if 'exec' not in _skip_tags(case)]

# Backward-compile cases: drop cases tagged 'backward' or 'backward_exec'.
test_backward_exec_case = [case for case in test_case
                           if 'backward' not in _skip_tags(case)
                           and 'backward_exec' not in _skip_tags(case)]
163
164
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """Provide the forward-compile case set to the GE graph pipeline.

    Returns:
        ``test_exec_case``: the cases whose 'skip' list does not contain 'exec'.
    """
    forward_cases = test_exec_case
    return forward_cases
168
169
@mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config)
def test_backward_exec():
    """Provide the gradient-compile case set to the GE graph pipeline.

    Returns:
        ``test_backward_exec_case``: the cases whose 'skip' list contains
        neither 'backward' nor 'backward_exec'.
    """
    backward_cases = test_backward_exec_case
    return backward_cases
173