# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.common import dtype
from mindspore.ops import operations as P
from mindspore.ops import prim_attr_register, PrimitiveWithInfer


def get_add(a, b):
    """Return the sum of the two operands."""
    total = a + b
    return total


def get_f(v):
    """Return the successor of ``v`` (i.e. ``v + 1``)."""
    incremented = v + 1
    return incremented


# Shared ReLU cell applied by get_relu.
relu = nn.ReLU()


def get_relu(x):
    """Apply the module-level ReLU cell to ``x`` and return the result."""
    activated = relu(x)
    return activated


# Shared primitive instance used by the wrapper below.
softmax_cross_entropy_with_logits = P.SoftmaxCrossEntropyWithLogits()


def get_softmax_cross_entropy_with_logits(logits, labels):
    """Run the shared SoftmaxCrossEntropyWithLogits primitive on the inputs."""
    loss = softmax_cross_entropy_with_logits(logits, labels)
    return loss


class TensorToScalar(PrimitiveWithInfer):
    """Test-only primitive: takes tensor inputs but produces a single scalar output."""

    @prim_attr_register
    def __init__(self):
        """Initialize TensorToScalar (no attributes)."""

    def __call__(self, logits, labels):
        # Execution is intentionally unsupported; only the infer path is exercised.
        raise NotImplementedError

    def infer_shape(self, logits_shape, label_shape):
        # A scalar is represented by the empty shape.
        return []

    def infer_dtype(self, logits_type, labels_type):
        # pylint: disable=unused-argument
        # Output dtype is fixed regardless of the input dtypes.
        return dtype.float64


# Shared instance of the test primitive used by the wrapper below.
tensorToScalar = TensorToScalar()


def get_tensor_to_scalar(logits, labels):
    """Invoke the shared TensorToScalar primitive on the given inputs."""
    result = tensorToScalar(logits, labels)
    return result


# Shared 64-channel 3x3 Conv2D with explicit padding of 1 and stride 2.
conv2d = P.Conv2D(out_channel=64,
                  kernel_size=(3, 3),
                  pad_mode="pad",
                  pad=1,
                  stride=2)


def get_conv2d(x, w):
    """Apply the shared Conv2D primitive to input ``x`` with weights ``w``."""
    out = conv2d(x, w)
    return out


# Shared depthwise convolution (multiplier 3, 3x3 kernel, pad 1, stride 2).
conv2dNative = P.DepthwiseConv2dNative(
    3,
    (3, 3),
    pad_mode="pad",
    pad=1,
    stride=2,
)


def get_conv2d_native(x, w):
    """Apply the shared depthwise Conv2D primitive to ``x`` with weights ``w``."""
    out = conv2dNative(x, w)
    return out


# Shared BiasAdd primitive used by the wrapper below.
biasAdd = P.BiasAdd()


def get_bias_add(x, b):
    """Add the bias vector ``b`` onto ``x`` via the shared BiasAdd primitive."""
    out = biasAdd(x, b)
    return out


def test_conv2d(out_channel, kernel_size, pad, stride, dilation):
    """Build a Conv2D primitive from the given hyper-parameters and return a
    closure that applies it to an input/weight pair."""
    op = P.Conv2D(out_channel=out_channel,
                  kernel_size=kernel_size,
                  pad_mode="pad",
                  pad=pad,
                  stride=stride,
                  dilation=dilation)

    def apply_conv(x, w):
        # Forward the captured primitive.
        return op(x, w)

    return apply_conv


def test_dropout():
    """Return a closure implementing dropout via the GenMask/DoMask primitive pair."""
    gen_mask = P.DropoutGenMask()
    do_mask = P.DropoutDoMask()
    get_shape = P.Shape()

    def apply_dropout(x, prob):
        # Generate a keep-mask for x's shape, then apply it with the same prob.
        mask = gen_mask(get_shape(x), prob)
        return do_mask(x, mask, prob)

    return apply_dropout