• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing Magphase Python API
17"""
18import numpy as np
19
20import mindspore.dataset as ds
21import mindspore.dataset.audio.transforms as audio
22from mindspore import log as logger
23
24
def test_magphase_pipeline():
    """
    Test Magphase (pipeline).

    A single dataset row holds two [real, imag] pairs, (3, -4) and (-5, 12).
    The row is mapped through ``audio.Magphase(power=1.0)``, which produces
    the "mag" and "phase" columns. Both columns are checked element-wise
    against precomputed values.
    """
    logger.info("Test Magphase pipeline.")

    input_data = [[[3.0, -4.0], [-5.0, 12.0]]]
    # sqrt(3**2 + 4**2) = 5 and sqrt(5**2 + 12**2) = 13.
    expected_mag = np.array([5.0, 13.0])
    # atan2(-4, 3) and atan2(12, -5), in radians.
    expected_phase = np.array([-0.927295, 1.965587])

    dataset = ds.NumpySlicesDataset(input_data, column_names=["col1"], shuffle=False)
    magphase_window = audio.Magphase(power=1.0)
    dataset = dataset.map(operations=magphase_window, input_columns=["col1"],
                          output_columns=["mag", "phase"], column_order=["mag", "phase"])

    num_rows = 0
    for mag, phase in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_allclose(mag, expected_mag, rtol=0, atol=0.00001)
        np.testing.assert_allclose(phase, expected_phase, rtol=0, atol=0.00001)
        num_rows += 1
    # The input has exactly one row. Without this check, an empty pipeline
    # would skip every assertion above and the test would pass vacuously.
    assert num_rows == 1

    logger.info("Finish testing Magphase.")
44
45
def test_magphase_eager():
    """
    Test Magphase (eager).

    The (2, 2, 2) input is read as four [real, imag] pairs. The op is
    constructed with no arguments, and its two outputs (magnitude and phase)
    are compared against (2, 2) arrays of expected values.
    """
    logger.info("Test Magphase eager.")

    # Pairs along the last axis: (41, 67), (34, 0), (69, 24), (78, 58).
    input_number = np.array([41, 67, 34, 0, 69, 24, 78, 58]).reshape((2, 2, 2)).astype("double")
    # sqrt(real**2 + imag**2) for each pair.
    mag = np.array([78.54934755, 34., 73.05477397, 97.20082304]).reshape((2, 2)).astype("double")
    # atan2(imag, real) for each pair, in radians.
    phase = np.array([1.02164342, 0, 0.33473684, 0.63938591]).reshape((2, 2)).astype("double")

    magphase_window = audio.Magphase()
    data1, data2 = magphase_window(input_number)
    # assert_allclose also verifies that the output shapes match the
    # expected (2, 2) arrays and reports the mismatching elements on failure.
    np.testing.assert_allclose(data1, mag, rtol=0, atol=0.00001)
    np.testing.assert_allclose(data2, phase, rtol=0, atol=0.00001)

    logger.info("Finish testing Magphase.")
61
62
def _check_magphase_error(error_type, message, input_number, power):
    """
    Build Magphase with ``power``, apply it to ``input_number``, and require failure.

    Args:
        error_type: Exception class that construction or application must raise.
        message: Substring that must appear in the raised exception's text.
        input_number: Array passed to the Magphase op.
        power: Value for the ``power`` argument of ``audio.Magphase``.

    Raises:
        AssertionError: If no exception is raised, or if the message does not match.
    """
    try:
        magphase_window = audio.Magphase(power=power)
        _ = magphase_window(input_number)
    except error_type as error:
        logger.info("Got an exception in Magphase: {}".format(str(error)))
        assert message in str(error)
    else:
        # Without this branch, an op that stopped validating its input would
        # let the test pass silently.
        raise AssertionError("Magphase(power={}) did not raise {} for input {!r}".format(
            power, error_type.__name__, input_number))


def test_magphase_exception():
    """
    Test that Magphase rejects invalid inputs and parameters.

    Covered cases:
    - a last dimension other than 2, for both a 1-D and a 2-D input
    - a string (non-numeric) input
    - a negative ``power``

    Each case must raise the expected exception with the expected message.
    """
    logger.info("Test Magphase not callable.")

    shape_message = "Magphase: input tensor is not in shape of <..., 2>."
    _check_magphase_error(RuntimeError, shape_message,
                          np.array([1, 2, 3, 4]).reshape(4,).astype("double"), power=2.0)
    _check_magphase_error(RuntimeError, shape_message,
                          np.array([1, 2, 3, 4]).reshape(1, 4).astype("double"), power=2.0)
    _check_magphase_error(RuntimeError, "Magphase: input tensor type should be int, float or double",
                          np.array(['test', 'test']).reshape(1, 2), power=2.0)
    # The power check fails while the op is being constructed, before any input is seen.
    _check_magphase_error(ValueError, "Input power is not within the required interval of [0, 16777216].",
                          np.array([1, 2, 3, 4]).reshape(2, 2).astype("double"), power=-1.0)

    logger.info("Finish testing Magphase.")
99
100
if __name__ == "__main__":
    # Run every test in this module, in order, when executed as a script.
    for run_test in (test_magphase_pipeline, test_magphase_eager, test_magphase_exception):
        run_test()
105