# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_partial """
import mindspore.ops.composite as C
import mindspore.ops.functional as F
from mindspore import context
from mindspore.common.api import ms_function


def setup_module(module):
    """Switch the MindSpore context to PyNative mode before this module's tests.

    Module-level setup hook (xunit-style ``setup_module`` name); ``module`` is
    accepted for that calling convention and is not used here.
    """
    context.set_context(mode=context.PYNATIVE_MODE)


def myadd(x, y):
    """Return ``x + y``; the plain two-argument function that gets partially applied."""
    return x + y


def partial_simple_add(x):
    """Return ``myadd`` with ``x`` pre-bound as its first argument."""
    return F.partial(myadd, x)


@ms_function
def full_simple_add(x, y):
    """Build the ``myadd`` partial for ``x`` and immediately apply it to ``y``."""
    return partial_simple_add(x)(y)


def test_full_simple_add():
    """Smoke test: compile and run the simple partial-application graph on ints."""
    outcome = full_simple_add(2, 5)
    print(outcome)


# Partial application over a MultitypeFuncGraph, which picks an overload
# based on the registered argument type names.
MULTI_ADD = C.MultitypeFuncGraph('add')


@MULTI_ADD.register("Int64", "Int64")
def add_int(x, y):
    """``Int64, Int64`` overload of ``MULTI_ADD``: scalar addition."""
    return F.scalar_add(x, y)


@MULTI_ADD.register("Float32", "Float32")
def add_float(x, y):
    """``Float32, Float32`` overload of ``MULTI_ADD``: scalar addition."""
    return F.scalar_add(x, y)


def partial_multi_add(x):
    """Return ``MULTI_ADD`` with ``x`` pre-bound as its first argument."""
    return F.partial(MULTI_ADD, x)


@ms_function
def full_multi_add(x, y, m, n):
    """Apply ``MULTI_ADD`` partials to ``(x, y)`` and ``(m, n)``; return both results as a pair."""
    left_sum = partial_multi_add(x)(y)
    right_sum = partial_multi_add(m)(n)
    return left_sum, right_sum


def test_full_multi_add():
    """Smoke test: one partial helper fed ints for one pair and floats for the other.

    Presumably this exercises both the Int64 and Float32 overloads; confirm
    against MindSpore's Python-scalar type mapping.
    """
    pair = full_multi_add(1, 2, 1.0, 2.0)
    print(pair)