# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from functools import partial
from typing import Tuple, List
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import numpy as np
import pytest

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Rolling(PrimitiveWithInfer):
    """
    Rolling op frontend implementation.

    Wraps a rolling-window reduction primitive taking a single input ``x``
    and producing a single output ``output`` of the same shape and dtype.
    """

    @prim_attr_register
    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        """Initialize Rolling: type-check every attribute and register IO names."""
        self.window = validator.check_value_type("window", window, [int], self.name)
        self.min_periods = validator.check_value_type("min_periods", min_periods, [int], self.name)
        self.center = validator.check_value_type("center", center, [bool], self.name)
        self.axis = validator.check_value_type("axis", axis, [int], self.name)
        self.closed = validator.check_value_type("closed", closed, [str], self.name)
        self.method = validator.check_value_type("method", method, [str], self.name)

        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def __infer__(self, x):
        """Infer the output description: same shape and dtype as ``x``, no constant value."""
        out_shapes = x['shape']
        return {
            'shape': tuple(out_shapes),
            'dtype': x['dtype'],
            'value': None
        }

    def infer_dtype(self, x_dtype):
        """Validate the input dtype and return it unchanged as the output dtype."""
        # NOTE(review): __infer__ is overridden above, so this hook may never be
        # invoked by the framework -- confirm before relying on this validation.
        validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64, mstype.int32, mstype.int64],
                                           self.name, True)
        return x_dtype


class RollingNet(nn.Cell):
    """Minimal cell that applies the ``Rolling`` primitive to its input."""

    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        super(RollingNet, self).__init__()
        self.rolling = Rolling(window, min_periods, center, axis, closed, method)

    def construct(self, x):
        """Run the rolling primitive on ``x``."""
        return self.rolling(x)


def get_window_bounds(num_values: int, window_size: int, center: bool, closed: str = 'right') -> Tuple[List, List]:
    """
    Compute the half-open ``[start, end)`` window bounds for every output position.

    Args:
        num_values: Number of elements along the rolling axis.
        window_size: Nominal window length.
        center: If True, windows are shifted right by ``(window_size - 1) // 2``.
        closed: One of 'left', 'both', 'right', 'neither'. 'left' and 'both'
            move the start one element earlier; 'left' and 'neither' move the
            end one element earlier.

    Returns:
        Two lists of length ``num_values`` with the start and end index of each
        window, both clipped to ``[0, num_values]``.
    """
    assert closed in {'left', 'both', 'right', 'neither'}
    offset = (window_size - 1) // 2 if center else 0

    end = np.arange(offset + 1, num_values + 1 + offset, dtype=np.int64)
    start = end - window_size
    if closed in {'left', 'both'}:
        start -= 1
    if closed in {'left', 'neither'}:
        end -= 1

    # Windows that run past either edge are truncated rather than wrapped.
    end = np.clip(end, 0, num_values)
    start = np.clip(start, 0, num_values)

    return list(start), list(end)


def numpy_rolling(array: np.ndarray, window: int, min_periods: int, center: bool, axis: int, closed: str,
                  method: str) -> np.ndarray:
    """
    NumPy reference implementation of the rolling reduction used as the test oracle.

    Args:
        array: Input data.
        window: Window length, must be positive.
        min_periods: Minimum number of elements a window must contain to be
            reduced; must satisfy ``0 < min_periods <= window``.
        center: Whether windows are centred (see ``get_window_bounds``).
        axis: Axis to roll along; negative values are allowed.
        closed: Window closure, forwarded to ``get_window_bounds``.
        method: One of 'max', 'min', 'mean', 'sum', 'std', 'var'
            ('std' and 'var' use ``ddof=1``).

    Returns:
        An array with the shape and dtype of ``array``. Positions whose window
        holds fewer than ``min_periods`` elements are NaN for floating dtypes
        and 0 otherwise.
    """
    assert window > 0
    assert 0 < min_periods <= window
    assert axis in range(-array.ndim, array.ndim)
    reduce_map = {'max': np.max, 'min': np.min, 'mean': np.mean, 'sum': np.sum, 'std': partial(np.std, ddof=1),
                  'var': partial(np.var, ddof=1)}
    assert method in reduce_map
    reduce_fn = reduce_map[method]

    start, end = get_window_bounds(array.shape[axis], window, center, closed)

    # Placeholder slab for under-populated windows. np.full (rather than
    # np.empty plus a conditional fill) guarantees it is never left
    # uninitialised, whatever the input dtype is.
    fill_shape = list(array.shape)
    fill_shape[axis] = 1
    fill_value = np.nan if np.issubdtype(array.dtype, np.floating) else 0
    filler = np.full(fill_shape, fill_value, dtype=np.float64)

    pieces = []
    for lo, hi in zip(start, end):
        if hi - lo < min_periods:
            pieces.append(filler)
            continue
        index = [slice(None)] * array.ndim
        index[axis] = slice(lo, hi)
        # keepdims=True leaves a length-1 rolling axis so the pieces can be
        # joined directly along that axis.
        pieces.append(reduce_fn(array[tuple(index)], axis=axis, keepdims=True))

    return np.concatenate(pieces, axis=axis).astype(array.dtype)


@pytest.mark.parametrize('shape', [(10, 8, 15, 7), (5, 3, 8, 10)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('window, min_periods', [(3, 3), (5, 3)])
@pytest.mark.parametrize('center', [True, False])
@pytest.mark.parametrize('axis', [2, 3, -1])
@pytest.mark.parametrize('closed', ['left', 'both', 'right', 'neither'])
@pytest.mark.parametrize('method', ['max', 'min', 'mean', 'sum', 'std', 'var'])
def test_two_way(shape: List[int], dtype, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str) -> None:
    """Check the Rolling op against the NumPy reference for one parameter combination."""
    if dtype in (np.int32, np.int64):
        # Pass dtype explicitly; otherwise randint always yields the platform
        # default integer type and the int32 case is never exercised.
        arr = np.random.randint(0, 100, size=shape, dtype=dtype)
    else:
        arr = np.random.random(shape).astype(dtype)
    expect_result = numpy_rolling(arr, window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
                                  method=method)
    rolling = RollingNet(window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
                         method=method)
    actual_result = rolling(Tensor(arr)).asnumpy()
    print('arr: \n', arr, arr.dtype, arr.shape)
    print('np: \n', expect_result, expect_result.dtype, expect_result.shape)
    print('mine: \n', actual_result, actual_result.dtype, actual_result.shape)
    print(f'center: {center}, axis: {axis}, method: {method}')
    assert np.allclose(expect_result, actual_result, equal_nan=True)