1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16import numpy as np 17import mindspore.context as context 18from mindspore import Tensor 19import mindspore.nn as nn 20from mindspore.nn import Cell 21from mindspore.ops import operations as P 22import mindspore.ops.functional as F 23import pytest 24 25context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 26# enable graph kernel optimization. 27context.set_context(enable_graph_kernel=True) 28 29 30class BertAttentionPiece(Cell): 31 def __init__(self): 32 super(BertAttentionPiece, self).__init__() 33 self.add = P.Add() 34 self.dropout = nn.Dropout(1 - 0.1) 35 self.softmax = nn.Softmax() 36 self.multiply_data = -10000.0 37 self.sub = P.Sub() 38 self.multiply = P.Mul() 39 self.get_dtype = P.DType() 40 self.cast = P.Cast() 41 42 def construct(self, attention_mask, attention_scores): 43 multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), 44 self.cast(attention_mask, self.get_dtype(attention_scores))) 45 adder = self.multiply(multiply_out, self.multiply_data) 46 attention_scores = self.add(adder, attention_scores) 47 attention_probs = self.softmax(attention_scores) 48 attention_probs = self.dropout(attention_probs) 49 return attention_probs 50 51 52def get_rtol_atol(dtype): 53 if dtype == np.float16: 54 return 1.e-3, 1.e-3 55 return 1.e-4, 1.e-4 56 57 58def compare_result(expect, output, dtype): 59 rtol, atol = get_rtol_atol(dtype) 60 if isinstance(expect, (list, tuple)): 61 assert isinstance(output, (list, tuple)) and len(expect) == len(output) 62 expect_list = list(expect) 63 output_list = list(output) 64 for e, o in zip(expect_list, output_list): 65 assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True) 66 else: 67 assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True) 68 69 70def get_softmax_output(x, y, enable_stitch_fusion): 71 # enable graph kernel stitch fusion. 72 if enable_stitch_fusion: 73 context.set_context(graph_kernel_flags="--enable_stitch_fusion=true") 74 net = BertAttentionPiece() 75 result = net(x, y) 76 return result 77 78 79def test_softmax(shape, dtype): 80 np.random.seed(0) 81 x = Tensor(np.random.normal(0, 1, shape).astype(dtype)) 82 y = Tensor(np.random.normal(0, 1, shape).astype(dtype)) 83 expect = get_softmax_output(x, y, False) 84 output = get_softmax_output(x, y, True) 85 compare_result(expect, output, dtype) 86 87 88@pytest.mark.level0 89@pytest.mark.platform_x86_gpu_training 90@pytest.mark.env_onecard 91def test_softmax_gpu(): 92 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 93 test_softmax([64, 12, 128, 128], np.float16) 94