• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing Mask op in DE
17"""
import operator

import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
24
# Maps each MindSpore dtype to the NumPy scalar type used to cast the
# reference result in mask_compare.
# np.bool / np.str were deprecated aliases (NumPy 1.20) and were removed in
# NumPy 1.24, so the NumPy scalar types np.bool_ / np.str_ are used instead.
mstype_to_np_type = {
    mstype.bool_: np.bool_,
    mstype.int8: np.int8,
    mstype.uint8: np.uint8,
    mstype.int16: np.int16,
    mstype.uint16: np.uint16,
    mstype.int32: np.int32,
    mstype.uint32: np.uint32,
    mstype.int64: np.int64,
    mstype.uint64: np.uint64,
    mstype.float16: np.float16,
    mstype.float32: np.float32,
    mstype.float64: np.float64,
    mstype.string: np.str_,
}
40
41
def mask_compare(array, op, constant, dtype=mstype.bool_):
    """Run ops.Mask over a one-row dataset and check it against a NumPy reference.

    Args:
        array: Input values (a list); passed to NumpySlicesDataset as ``[array]``.
        op: ops.Relational member selecting the comparison.
        constant: Value each element is compared against. The reference casts
            it to the input array's NumPy dtype before comparing.
        dtype: mstype of the generated mask; must be a key of mstype_to_np_type.

    Raises:
        RuntimeError: propagated from the dataset pipeline while iterating
            (the exception tests rely on this being raised first).
        ValueError: if ``op`` is not one of the six supported relations.
        AssertionError: if the pipeline yields no rows or a row differs from
            the reference.
    """
    compare_fns = {
        ops.Relational.EQ: operator.eq,
        ops.Relational.NE: operator.ne,
        ops.Relational.GT: operator.gt,
        ops.Relational.GE: operator.ge,
        ops.Relational.LT: operator.lt,
        ops.Relational.LE: operator.le,
    }

    data = ds.NumpySlicesDataset([array])
    array = np.array(array)
    data = data.map(operations=ops.Mask(op, constant, dtype))

    # Drain the pipeline before building the reference. Pipeline errors must
    # surface first, because np.array(constant, dtype=...) below can itself
    # raise for an incompatible constant (e.g. "3.5" against an int array).
    outputs = [row[0].asnumpy() for row in data]
    assert outputs, "Mask pipeline produced no rows"

    compare = compare_fns.get(op)
    if compare is None:
        raise ValueError(f"unsupported relational op: {op!r}")

    # Compute the reference once from the untouched input. Recomputing it per
    # row from a reassigned variable would compound across rows.
    expected = compare(array, np.array(constant, dtype=array.dtype))
    expected = expected.astype(dtype=mstype_to_np_type[dtype])

    for actual in outputs:
        np.testing.assert_array_equal(expected, actual)
63
64
def test_mask_int_comparison():
    """Compare an int input with an int constant for every non-string mask dtype."""
    relations = (
        ops.Relational.EQ,
        ops.Relational.NE,
        ops.Relational.LT,
        ops.Relational.LE,
        ops.Relational.GT,
        ops.Relational.GE,
    )
    numeric_types = [t for t in mstype_to_np_type if t != mstype.string]
    for mask_type in numeric_types:
        for relation in relations:
            mask_compare([1, 2, 3, 4, 5], relation, 3, mask_type)
75
76
def test_mask_float_comparison():
    """Compare a float input with an int constant for every non-string mask dtype."""
    values = [1.5, 2.5, 3., 4.5, 5.5]
    relations = [ops.Relational.EQ, ops.Relational.NE, ops.Relational.LT,
                 ops.Relational.LE, ops.Relational.GT, ops.Relational.GE]
    for out_type in mstype_to_np_type:
        # String masks are unsupported; that path is covered by the exception test.
        if out_type != mstype.string:
            for rel in relations:
                mask_compare(values, rel, 3, out_type)
87
88
def test_mask_float_comparison2():
    """Compare an int input with a float constant for every non-string mask dtype."""
    relations = (ops.Relational.EQ, ops.Relational.NE, ops.Relational.LT,
                 ops.Relational.LE, ops.Relational.GT, ops.Relational.GE)
    for target in mstype_to_np_type:
        if target == mstype.string:
            continue
        for relation in relations:
            mask_compare([1, 2, 3, 4, 5], relation, 3.5, target)
99
100
def test_mask_string_comparison():
    """Compare a string input with a string constant for every non-string mask dtype."""
    words = ["1.5", "2.5", "3.", "4.5", "5.5"]
    relations = (ops.Relational.EQ, ops.Relational.NE, ops.Relational.LT,
                 ops.Relational.LE, ops.Relational.GT, ops.Relational.GE)
    mask_types = (t for t in mstype_to_np_type if t != mstype.string)
    for mask_type in mask_types:
        for relation in relations:
            mask_compare(words, relation, "3.", mask_type)
111
112
def test_mask_exceptions_str():
    """Invalid Mask configurations must raise RuntimeError with a descriptive message."""
    type_mismatch = "input datatype does not match the value datatype"
    bad_mask_type = "Only supports bool or numeric datatype for generated mask type"
    cases = [
        (([1, 2, 3, 4, 5], ops.Relational.EQ, "3.5"), type_mismatch),
        ((["1", "2", "3", "4", "5"], ops.Relational.EQ, 3.5), type_mismatch),
        ((["1", "2", "3", "4", "5"], ops.Relational.EQ, "3.5", mstype.string), bad_mask_type),
    ]
    for call_args, expected_message in cases:
        with pytest.raises(RuntimeError) as info:
            mask_compare(*call_args)
        assert expected_message in str(info.value)
125
126
if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for test_fn in (
            test_mask_int_comparison,
            test_mask_float_comparison,
            test_mask_float_comparison2,
            test_mask_string_comparison,
            test_mask_exceptions_str,
    ):
        test_fn()
133