# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
from tests.st.control.cases_register import case_register
import mindspore.context as context
import mindspore as ms
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore import ops, Tensor, nn

# Every case in this module is compiled and executed in graph mode.
context.set_context(mode=context.GRAPH_MODE)


@case_register.level0
@case_register.target_gpu
@case_register.target_ascend
def test_dde_make_tuple_joined_with_tuple_output_primitive():
    """
    Feature: Eliminate unused element for tuple.
    Description: One branch returns an explicitly built tuple and the other branch returns the
        tuple output of a primitive (topk); the two branch results are joined.
    Expectation: Correct result and no exception.
    """

    @ms.jit
    def topk_fun(x, k):
        # The branch is taken on a Tensor condition: one side builds the tuple
        # by hand, the other side forwards the tuple returned by topk.
        if k == 0:
            output = (ms.ops.ones((0,), dtype=ms.float32), ms.ops.ones((0,), dtype=ms.int32))
        else:
            output = ms.ops.topk(x, k, None, True, True)
        return output

    x = ms.tensor([1., 2., 3.])
    k = ms.tensor([0])
    out = topk_fun(x, k)
    expect_out0 = ms.ops.ones((0,), dtype=ms.float32)
    expect_out1 = ms.ops.ones((0,), dtype=ms.int32)

    out0_np = out[0].asnumpy()
    out1_np = out[1].asnumpy()
    expect0_np = expect_out0.asnumpy()
    expect1_np = expect_out1.asnumpy()

    # The expected arrays are empty, and np.allclose is vacuously True for
    # empty (and broadcastable) operands, so it cannot catch a wrong shape or
    # dtype on its own. Compare both explicitly before the value check.
    assert out0_np.shape == expect0_np.shape
    assert out0_np.dtype == expect0_np.dtype
    assert out1_np.shape == expect1_np.shape
    assert out1_np.dtype == expect1_np.dtype
    assert np.allclose(out0_np, expect0_np)
    assert np.allclose(out1_np, expect1_np)


@case_register.level0
@case_register.target_gpu
@case_register.target_ascend
def test_dde_parameter_converted_to_value_tuple():
    """
    Feature: Eliminate unused element for tuple.
    Description: The value_tuple is converted from the parameter which is not set sequence_nodes.
    Expectation: Correct result and no exception.
    """

    def _old_norm(norm_type, x):
        # p-norm of x: (sum(x ** p)) ** (1 / p), cast back to the dtype of x.
        out = F.pow((F.reduce_sum(F.pow(x, norm_type))), 1. / norm_type).astype(x.dtype)
        return out

    class ClipByNormFuncNet(nn.Cell):
        """Scale the inputs by max_norm / (total_norm + 1e-6) when that coefficient is below 1."""

        def __init__(self, max_norm, norm_type=2.0, error_if_nonfinite=False):
            super().__init__()
            self.max_norm = max_norm
            self.norm_type = norm_type
            # Stored only; construct() below does not read this attribute.
            self.error_if_nonfinite = error_if_nonfinite
            self.partial_op = P.Partial()
            self.hyper_map = C.HyperMap()

        def construct(self, *x):
            is_tensor = False
            # NOTE(review): x is collected through *x, so this branch looks
            # unreachable here; presumably kept to mirror the code that
            # originally triggered the issue - confirm before removing.
            if isinstance(x, Tensor):
                x = [x]
                is_tensor = True
            # Norm of the stacked per-input norms.
            total_norm = _old_norm(self.norm_type,
                                   F.stack(self.hyper_map(self.partial_op(_old_norm, self.norm_type), x)))
            clip_coef = self.max_norm / (total_norm + 1e-6)
            # One branch returns a freshly mapped sequence, the other returns
            # the input sequence itself.
            if clip_coef < 1:
                ret = self.hyper_map(self.partial_op(F.mul, clip_coef), x)
            else:
                ret = x
            if is_tensor:
                return ret[0]
            return ret

    class GradNetWrtX(nn.Cell):
        """Wrap a network with GradOperation(sens_param=True) and forward all inputs to it."""

        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(sens_param=True)

        def construct(self, *x):
            gradient_function = self.grad_op(self.net)
            return gradient_function(*x)

    ms.set_context(mode=ms.GRAPH_MODE)
    net = ClipByNormFuncNet(max_norm=1, norm_type=2, error_if_nonfinite=True)
    net.set_train()
    x = [ops.ones((2, 2)), ops.ones((2,))]
    ms_output = net(*x)
    # The forward output is appended after the inputs; with sens_param=True it
    # is presumably consumed as the sensitivity value - see GradOperation docs.
    output = GradNetWrtX(net)(*x, ms_output)
    expect_out = np.array([[0.4082482, 0.4082482], [0.4082482, 0.4082482]]).astype(np.float32)
    assert np.allclose(output.asnumpy(), expect_out)