# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.common import dtype
from mindspore.ops import operations as P
from mindspore.ops import prim_attr_register, PrimitiveWithInfer


def get_add(a, b):
    # Sum of the two inputs via the `+` operator.
    return a + b


def get_f(v):
    # The input incremented by one.
    return v + 1


# Shared ReLU cell applied by get_relu.
relu = nn.ReLU()


def get_relu(x):
    # Apply the module-level ReLU cell to x.
    return relu(x)


# Shared primitive instance used by get_softmax_cross_entropy_with_logits.
softmax_cross_entropy_with_logits = P.SoftmaxCrossEntropyWithLogits()


def get_softmax_cross_entropy_with_logits(logits, labels):
    # Forward both inputs unchanged to the module-level primitive.
    return softmax_cross_entropy_with_logits(logits, labels)


class TensorToScalar(PrimitiveWithInfer):
    """Test primitive: accepts tensor inputs but yields a single scalar output."""

    @prim_attr_register
    def __init__(self):
        """Initialize the primitive; it takes no constructor arguments."""

    def __call__(self, logits, labels):
        # Deliberately no direct (eager) implementation.
        raise NotImplementedError

    def infer_shape(self, logits_shape, label_shape):
        # pylint: disable=unused-argument
        # Empty shape => scalar output, independent of the input shapes.
        return []

    def infer_dtype(self, logits_type, labels_type):
        # pylint: disable=unused-argument
        # Output dtype is always float64, independent of the input dtypes.
        return dtype.float64


# Shared TensorToScalar instance used by get_tensor_to_scalar.
tensorToScalar = TensorToScalar()


def get_tensor_to_scalar(logits, labels):
    # Route both inputs through the shared TensorToScalar primitive.
    # Calling it eagerly hits TensorToScalar.__call__ (NotImplementedError).
    return tensorToScalar(logits, labels)


# Conv2D with 64 output channels, a 3x3 kernel, explicit padding of 1, stride 2.
conv2d = P.Conv2D(64, (3, 3), pad_mode="pad", pad=1, stride=2)


def get_conv2d(x, w):
    # Convolve x with weight w using the shared Conv2D primitive.
    return conv2d(x, w)


# Depthwise convolution, 3x3 kernel, explicit padding of 1, stride 2.
# NOTE(review): the leading 3 is presumably the channel multiplier — confirm
# against the DepthwiseConv2dNative API.
conv2dNative = P.DepthwiseConv2dNative(3, (3, 3), pad_mode="pad", pad=1, stride=2)


def get_conv2d_native(x, w):
    # Apply the shared depthwise convolution primitive to x with weight w.
    return conv2dNative(x, w)


# Shared BiasAdd primitive used by get_bias_add.
biasAdd = P.BiasAdd()


def get_bias_add(x, b):
    # Add bias b to x via the shared BiasAdd primitive.
    return biasAdd(x, b)


def test_conv2d(out_channel, kernel_size, pad, stride, dilation):
    """Build a padded Conv2D primitive from the arguments; return a function applying it.

    NOTE(review): despite the ``test_`` prefix this is a factory that takes
    arguments. If pytest ever collects this module, it would try to resolve
    these parameters as fixtures — confirm the module is not collected.
    """
    conv_op = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                       pad_mode="pad", pad=pad, stride=stride, dilation=dilation)

    def get_conv(x, w):
        # Apply the captured Conv2D primitive to x with weight w.
        return conv_op(x, w)

    return get_conv


def test_dropout():
    """Return a function applying dropout via DropoutGenMask followed by DropoutDoMask."""
    gen_mask = P.DropoutGenMask()
    do_mask = P.DropoutDoMask()
    shape_of = P.Shape()

    def get_dropout(x, prob):
        # Generate a mask from x's shape, then apply it to x with the same prob.
        return do_mask(x, gen_mask(shape_of(x), prob), prob)

    return get_dropout