# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test vjp (``mindspore.nn.grad.Vjp``, vector-Jacobian product) in pynative mode."""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Vjp

context.set_context(mode=context.PYNATIVE_MODE)


def _input_matrix():
    """Return the float32 input tensor [[1, 2], [3, 4]] shared by the tests."""
    return Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))


def _ones_matrix():
    """Return a float32 2x2 tensor of ones, used as the vector ``v``."""
    return Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))


class SingleInputNet(nn.Cell):
    """Cell computing ``x ** 3`` from a single tensor input."""

    def construct(self, x):
        return x**3


class MultipleInputsOutputNet(nn.Cell):
    """Cell mapping ``(x, y)`` to the output pair ``(2 * x, y ** 3)``."""

    def construct(self, x, y):
        return 2*x, y**3


def test_vjp_single_input_pynative():
    """Vjp of ``x ** 3`` with ``v`` = ones returns ``x ** 3`` and ``3 * x ** 2``."""
    x = _input_matrix()
    v = _ones_matrix()
    net = SingleInputNet()
    primal, grad = Vjp(net)(x, v)
    expect_primal = np.array([[1, 8], [27, 64]], dtype=np.float32)
    # d(x**3)/dx = 3 * x**2, contracted with an all-ones v.
    expect_grad = np.array([[3, 12], [27, 48]], dtype=np.float32)
    assert np.allclose(primal.asnumpy(), expect_primal)
    assert np.allclose(grad.asnumpy(), expect_grad)


def test_vjp_multiple_inputs_default_v_pynative():
    """Vjp of a two-input, two-output net with one ``v`` per output.

    NOTE: despite the test name, ``v`` is passed explicitly as a tuple with
    one entry per network output.
    """
    x = _input_matrix()
    y = _input_matrix()
    v = _ones_matrix()
    net = MultipleInputsOutputNet()
    primal, grad = Vjp(net)(x, y, (v, v))
    assert isinstance(primal, tuple) and len(primal) == 2
    assert isinstance(grad, tuple) and len(grad) == 2
    expect_primal_0 = np.array([[2, 4], [6, 8]], dtype=np.float32)
    expect_primal_1 = np.array([[1, 8], [27, 64]], dtype=np.float32)
    # d(2x)/dx = 2 and d(y**3)/dy = 3 * y**2, each contracted with an all-ones v.
    expect_grad_0 = np.array([[2, 2], [2, 2]], dtype=np.float32)
    expect_grad_1 = np.array([[3, 12], [27, 48]], dtype=np.float32)
    assert np.allclose(primal[0].asnumpy(), expect_primal_0)
    assert np.allclose(primal[1].asnumpy(), expect_primal_1)
    assert np.allclose(grad[0].asnumpy(), expect_grad_0)
    assert np.allclose(grad[1].asnumpy(), expect_grad_1)


def test_vjp_wrong_input_v_pynative():
    """A 2-tuple ``v`` for a single-tensor output is expected to raise TypeError."""
    x = _input_matrix()
    v = _ones_matrix()
    net = SingleInputNet()
    with pytest.raises(TypeError):
        Vjp(net)(x, (v, v))


def test_vjp_wrong_input_v_2_pynative():
    """A 1-tuple ``v`` for a single-tensor output is expected to raise TypeError."""
    x = _input_matrix()
    v = _ones_matrix()
    net = SingleInputNet()
    with pytest.raises(TypeError):
        Vjp(net)(x, (v,))


def test_vjp_wrong_input_pynative():
    """Passing two inputs to a one-input net is expected to raise TypeError."""
    x = _input_matrix()
    y = _input_matrix()
    v = _ones_matrix()
    net = SingleInputNet()
    with pytest.raises(TypeError):
        Vjp(net)(x, y, v)


def test_vjp_wrong_input_2_pynative():
    """Packing the inputs into one tuple instead of passing them positionally
    is expected to raise TypeError."""
    x = _input_matrix()
    y = _input_matrix()
    v = _ones_matrix()
    net = MultipleInputsOutputNet()
    with pytest.raises(TypeError):
        Vjp(net)((x, y), (v, v))