• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15import numpy as np
16import pytest
17
18import mindspore.dataset as ds
19import mindspore.dataset.audio.transforms as a_c_trans
20
21
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays agree element-wise within the given tolerances.

    An element is a mismatch unless ``|data_me - data_expected| <= atol + rtol * |data_expected|``.
    This is the ``numpy.isclose`` criterion, with NaN-vs-NaN treated as equal. The assertion
    fails when the fraction of mismatched elements is not strictly below ``rtol``.

    Args:
        data_expected (numpy.ndarray): Reference values.
        data_me (numpy.ndarray): Values produced by the operation under test.
        rtol (float): Relative tolerance. It is also reused as the allowed mismatch fraction.
        atol (float): Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or too many elements are out of tolerance.
    """
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    if total_count == 0:
        # Nothing to compare; computing the mismatch ratio would divide by zero.
        return
    error = np.abs(data_expected - data_me)
    # Use isclose rather than `error > tolerance`: a NaN error compares False under `>`,
    # which would let NaN outputs (or opposite infinities) pass silently.
    greater = ~np.isclose(data_me, data_expected, rtol=rtol, atol=atol, equal_nan=True)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
        data_expected[greater], data_me[greater], error[greater])
30
31
def test_func_dc_shift_eager():
    """
    Feature: DCShift operation
    Description: Call DCShift(1.0, 0.04) directly (eager mode) on a 1-D float64 array of 10 values
    Expectation: The result passes count_unequal_element against the precomputed values with
        rtol = atol = 1e-4
    """
    waveform = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93, 0.71, 0.61], dtype=np.double)
    want = np.array([0.0400, 0.0400, -0.0400, -0.2600, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
                    dtype=np.double)
    shifted = a_c_trans.DCShift(1.0, 0.04)(waveform)
    count_unequal_element(want, shifted, 0.0001, 0.0001)
42
43
def test_func_dc_shift_pipeline():
    """
    Feature: DCShift operation
    Description: Map DCShift(0.8, 0.03) over a NumpySlicesDataset built from a 2x4 float64 array
    Expectation: The pipeline yields exactly as many items as there are expected rows, and each
        item passes count_unequal_element against its expected row with rtol = atol = 1e-4
    """
    arr = np.array([[1.14, -1.06, 0.94, 0.90], [-1.11, 1.40, -0.33, 1.43]], dtype=np.double)
    expected = np.array([[0.2300, -0.2600, 0.2300, 0.2300], [-0.3100, 0.2300, 0.4700, 0.2300]], dtype=np.double)
    dataset = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False)
    dcshift_op = a_c_trans.DCShift(0.8, 0.03)
    dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
    rows = [item['col1'] for item in dataset.create_dict_iterator(output_numpy=True)]
    # zip() stops at the shorter input, so check the count explicitly; otherwise an empty,
    # truncated or over-long pipeline output would pass vacuously.
    assert len(rows) == len(expected)
    for actual_row, expected_row in zip(rows, expected):
        count_unequal_element(expected_row, actual_row, 0.0001, 0.0001)
55
56
def test_func_dc_shift_pipeline_error():
    """
    Feature: DCShift operation
    Description: Build a pipeline that uses DCShift with shift 2.5, outside [-2.0, 2.0]
    Expectation: ValueError is raised with the "not within the required interval" message
    """
    # np.float was only an alias of the builtin float (i.e. float64); it was deprecated in
    # NumPy 1.20 and removed in 1.24, so spell out the concrete dtype.
    arr = np.random.uniform(-2, 2, size=1000).astype(np.float64)
    label = np.random.sample((1000, 1))
    dataset = ds.NumpySlicesDataset((arr, label), column_names=["col1", "col2"], shuffle=False)
    with pytest.raises(ValueError, match=r"Input shift is not within the required interval of \[-2.0, 2.0\]."):
        dcshift_op = a_c_trans.DCShift(2.5, 0.03)
        dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
        # Drain the pipeline inside the raises-context so the error is still caught even if
        # parameter validation were deferred until execution.
        for _ in dataset.create_dict_iterator(output_numpy=True):
            pass
71
72
if __name__ == "__main__":
    # Run every test in this module in declaration order when executed as a script.
    for test_case in (test_func_dc_shift_eager,
                      test_func_dc_shift_pipeline,
                      test_func_dc_shift_pipeline_error):
        test_case()
77