• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16from functools import partial
17from typing import Tuple, List
18import mindspore.context as context
19import mindspore.nn as nn
20from mindspore import Tensor
21from mindspore.ops import PrimitiveWithInfer, prim_attr_register
22from mindspore._checkparam import Validator as validator
23from mindspore.common import dtype as mstype
24import numpy as np
25import pytest
26
# Module-level configuration, applied once at import: every test in this file runs with
# MindSpore in graph mode on the CPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
28
29
class Rolling(PrimitiveWithInfer):
    """
        Rolling op frontend implementation.

        Only attribute validation and shape/dtype inference live here; the computation itself is
        performed by the backend kernel. The output has the same shape and dtype as the input ``x``.
    """

    @prim_attr_register
    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        """Initialize Rolling"""
        self.window = validator.check_value_type("window", window, [int], self.name)
        self.min_periods = validator.check_value_type("min_periods", min_periods, [int], self.name)
        self.center = validator.check_value_type("center", center, [bool], self.name)
        self.axis = validator.check_value_type("axis", axis, [int], self.name)
        self.closed = validator.check_value_type("closed", closed, [str], self.name)
        self.method = validator.check_value_type("method", method, [str], self.name)

        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def __infer__(self, x):
        # PrimitiveWithInfer's default __infer__ is what dispatches to infer_dtype; because this
        # class overrides __infer__, the dtype check must be invoked explicitly or it never runs.
        out_dtype = self.infer_dtype(x['dtype'])
        return {
            'shape': tuple(x['shape']),
            'dtype': out_dtype,
            'value': None
        }

    def infer_dtype(self, x_dtype):
        """Reject input tensors whose element type the kernel does not support; return ``x_dtype``."""
        validator.check_tensor_dtype_valid("x", x_dtype,
                                           [mstype.float32, mstype.float64, mstype.int32, mstype.int64],
                                           self.name)
        return x_dtype
60
61
class RollingNet(nn.Cell):
    """Minimal Cell wrapping a single ``Rolling`` primitive built from the given settings."""

    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        super().__init__()
        # All configuration is handed to the primitive unchanged and in the same order.
        self.rolling = Rolling(window, min_periods, center, axis, closed, method)

    def construct(self, x):
        # Forward the input straight through the rolling primitive.
        return self.rolling(x)
70
71
def get_window_bounds(num_values: int, window_size: int, center: bool, closed: str = 'right') -> Tuple[List, List]:
    """Return the ``[start, end)`` index bounds of a fixed-size window at every position.

    Args:
        num_values: length of the axis being rolled over.
        window_size: number of elements in each window.
        center: when True, windows are shifted forward by ``(window_size - 1) // 2`` so they
            straddle each position instead of ending at it.
        closed: which endpoints are inclusive: 'left', 'both', 'right' or 'neither'.

    Returns:
        Two lists (starts, ends) of length ``num_values``, each clipped to ``[0, num_values]``.
    """
    assert closed in {'left', 'both', 'right', 'neither'}
    shift = 0
    if center:
        shift = (window_size - 1) // 2

    # A left-inclusive window reaches one element further back; a right-exclusive one
    # stops one element earlier.
    widen_left = closed in ('left', 'both')
    trim_right = closed in ('left', 'neither')

    upper = np.arange(num_values, dtype=np.int64) + shift + 1
    lower = upper - window_size - int(widen_left)
    upper = upper - int(trim_right)

    starts = np.clip(lower, 0, num_values)
    ends = np.clip(upper, 0, num_values)
    return list(starts), list(ends)
87
88
def numpy_rolling(array: np.ndarray, window: int, min_periods: int, center: bool, axis: int, closed: str,
                  method: str) -> np.ndarray:
    """Reference rolling-window reduction computed with NumPy.

    Window bounds for each position along ``axis`` come from ``get_window_bounds``. Windows with
    fewer than ``min_periods`` elements get a fill value instead of being reduced.

    Args:
        array: input data.
        window: window length; must be positive.
        min_periods: minimum window length required to reduce; ``0 < min_periods <= window``.
        center: passed to ``get_window_bounds`` to centre windows on each position.
        axis: axis to roll along; negative values count from the end.
        closed: inclusive endpoints: 'left', 'both', 'right' or 'neither'.
        method: one of 'max', 'min', 'mean', 'sum', 'std', 'var' (std/var use ``ddof=1``).

    Returns:
        Array with the same shape and dtype as ``array``. Too-short windows hold NaN for
        floating/complex dtypes and 0 for every other dtype.
    """
    assert window > 0
    assert 0 < min_periods <= window
    assert axis in range(-array.ndim, array.ndim)
    reduce_map = {'max': np.max, 'min': np.min, 'mean': np.mean, 'sum': np.sum, 'std': partial(np.std, ddof=1),
                  'var': partial(np.var, ddof=1)}
    assert method in reduce_map
    reduce_fn = reduce_map[method]

    size = array.shape[axis]
    if size == 0:
        # No windows to reduce; np.concatenate would reject an empty list of pieces.
        return array.copy()

    start, end = get_window_bounds(size, window, center, closed)

    # Shape of one output slice: the input shape with the rolled axis collapsed to 1.
    shape = list(array.shape)
    shape[axis] = 1
    # NaN can only be represented by inexact (floating/complex) dtypes; use 0 for the rest.
    fill_value = np.nan if np.issubdtype(array.dtype, np.inexact) else 0

    pieces = []
    index = [slice(None)] * array.ndim
    for i, j in zip(start, end):
        if j - i < min_periods:
            pieces.append(np.full(shape, fill_value))
        else:
            index[axis] = slice(i, j)
            pieces.append(reduce_fn(array[tuple(index)], axis=axis, keepdims=True))

    # Each piece has length 1 along ``axis``, so concatenating restores array.shape directly.
    return np.concatenate(pieces, axis=axis).astype(array.dtype)
120
121
@pytest.mark.parametrize('shape', [(10, 8, 15, 7), (5, 3, 8, 10)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('window, min_periods', [(3, 3), (5, 3)])
@pytest.mark.parametrize('center', [True, False])
@pytest.mark.parametrize('axis', [2, 3, -1])
@pytest.mark.parametrize('closed', ['left', 'both', 'right', 'neither'])
@pytest.mark.parametrize('method', ['max', 'min', 'mean', 'sum', 'std', 'var'])
def test_two_way(shape: Tuple[int, ...], dtype, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str) -> None:
    """Check the Rolling op against the NumPy reference for one parameter combination."""
    if dtype in (np.int32, np.int64):
        # Pass dtype explicitly: without it randint returns the platform default integer type,
        # so one of the int32/int64 cases would silently test the other dtype.
        arr = np.random.randint(0, 100, size=shape, dtype=dtype)
    else:
        arr = np.random.random(shape).astype(dtype)
    expect_result = numpy_rolling(arr, window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
                                  method=method)
    rolling = RollingNet(window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
                         method=method)
    actual_result = rolling(Tensor(arr)).asnumpy()
    # Diagnostic output; pytest only shows captured stdout when the assertion below fails.
    print('arr: \n', arr, arr.dtype, arr.shape)
    print('np: \n', expect_result, expect_result.dtype, expect_result.shape)
    print('mine: \n', actual_result, actual_result.dtype, actual_result.shape)
    print(f'center: {center}, axis: {axis}, method: {method}')
    # equal_nan: windows shorter than min_periods are NaN in the float reference.
    assert np.allclose(expect_result, actual_result, equal_nan=True)
145