# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Compare the MindSpore CPU ``Rank`` primitive against ``pandas.DataFrame.rank``."""

from typing import List
from random import sample, seed
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import numpy as np
import pandas as pd
import pytest

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Rank(PrimitiveWithInfer):
    """
    Rank op frontend implementation
    """

    # Backend kernel fields these attributes are presumably mapped to
    # (kept from the original author's note; verify against the kernel):
    #   size_t axis_{0};
    #   rank::Method method_{rank::MethodNotDefined};
    #   rank::NaOption option_{rank::OptionNotDefined};
    #   bool ascending_{true};
    #   bool pct_{false};
    @prim_attr_register
    def __init__(self, axis: int, method: str, na_option: str, ascending: bool, pct: bool):
        """Initialize Rank.

        Each argument is type-checked and stored as an attribute of the same
        name.

        Args:
            axis: must be an ``int``.
            method: must be a ``str``.
            na_option: must be a ``str``.
            ascending: must be a ``bool``.
            pct: must be a ``bool``.
        """
        self.axis = validator.check_value_type("axis", axis, [int], self.name)
        self.method = validator.check_value_type("method", method, [str], self.name)
        self.na_option = validator.check_value_type("na_option", na_option, [str], self.name)
        self.ascending = validator.check_value_type("ascending", ascending, [bool], self.name)
        self.pct = validator.check_value_type("pct", pct, [bool], self.name)

        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def __infer__(self, x):
        """Infer the output: same shape as ``x``, dtype float32, no constant value."""
        out_shapes = x['shape']
        return {
            'shape': tuple(out_shapes),
            'dtype': mstype.float32,
            'value': None
        }


class RankNet(nn.Cell):
    """Minimal cell that applies a single ``Rank`` primitive to its input."""

    def __init__(self, axis: int, method: str, na_option: str, ascending: bool, pct: bool):
        """Build the wrapped ``Rank`` primitive from the given options."""
        super().__init__()
        self.rank = Rank(axis, method, na_option, ascending, pct)

    def construct(self, x):
        """Return ``Rank`` applied to ``x``."""
        return self.rank(x)


def pandas_rank(arr, **kwargs):
    """Return ``pandas.DataFrame(arr).rank(**kwargs)`` as a NumPy array.

    A 1-D ``arr`` becomes a single-column frame, so callers flatten the result.
    """
    frame = pd.DataFrame(arr)
    return frame.rank(**kwargs).to_numpy()


def _make_input(shape, dtype):
    """Build a reproducible test array of the given shape and dtype.

    Integer dtypes get random values in ``[0, 100)``.  Floating dtypes get
    uniform values in ``[0, 1)`` with ``size // 10`` entries replaced by NaN.
    """
    np.random.seed(0)
    # ``sample`` draws from Python's ``random`` module, which ``np.random.seed``
    # does not affect; seed it too so NaN positions are the same on every run.
    seed(0)

    if dtype in (np.int32, np.int64):
        return np.random.randint(0, 100, size=shape).astype(dtype)

    arr = np.random.random(size=shape).astype(dtype)
    arr.flat[sample(range(arr.size), arr.size // 10)] = np.nan
    return arr


def _dump(arr, pd_result, mind_result):
    """Print input and both results; pytest shows this output on failure."""
    print('arr: \n', arr, arr.dtype, arr.shape)
    print('pandas: \n', pd_result, pd_result.dtype, pd_result.shape)
    print('mind: \n', mind_result, mind_result.dtype, mind_result.shape)


@pytest.mark.parametrize('shape', [(10,)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('method', ['dense', 'first', 'max', 'min', 'average'])
@pytest.mark.parametrize('na_option', ["keep", "top", "bottom"])
@pytest.mark.parametrize('ascending', [True, False])
@pytest.mark.parametrize('pct', [False, True])
def test_rank_1d(shape: List[int], dtype, method: str, ascending: bool, pct: bool, na_option: str):
    """Check ``Rank`` on 1-D input against pandas for every option combination."""
    arr = _make_input(shape, dtype)

    pd_result = pandas_rank(arr, method=method, ascending=ascending, pct=pct, na_option=na_option).flatten()
    rank = RankNet(axis=0, method=method, ascending=ascending, pct=pct, na_option=na_option)
    mind_result = rank(Tensor(arr)).asnumpy()

    _dump(arr, pd_result, mind_result)
    print(f'method: {method}, ascending: {ascending}, pct: {pct} na_option: {na_option}')
    assert np.allclose(pd_result, mind_result, equal_nan=True)


@pytest.mark.parametrize('shape', [(5, 6)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('method', ['dense', 'first', 'max', 'min', 'average'])
@pytest.mark.parametrize('na_option', ["keep", "top", "bottom"])
@pytest.mark.parametrize('axis', [0, 1])
@pytest.mark.parametrize('ascending', [True, False])
@pytest.mark.parametrize('pct', [False, True])
def test_rank_2d(shape: List[int], dtype, method: str, ascending: bool, pct: bool, axis: int, na_option: str):
    """Check ``Rank`` on 2-D input along each axis against pandas."""
    arr = _make_input(shape, dtype)

    pd_result = pandas_rank(arr, method=method, ascending=ascending, pct=pct, na_option=na_option, axis=axis)
    rank = RankNet(axis=axis, method=method, ascending=ascending, pct=pct, na_option=na_option)
    mind_result = rank(Tensor(arr)).asnumpy()

    _dump(arr, pd_result, mind_result)
    print(f'axis: {axis}, method: {method}, ascending: {ascending}, pct: {pct} na_option: {na_option}')
    assert np.allclose(pd_result, mind_result, equal_nan=True)