# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing unique op in DE
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops


def compare(array, res, idx, cnt):
    """Run ``ops.Unique()`` over ``array`` and check its three output columns.

    The input is wrapped in a single-row ``NumpySlicesDataset`` with column
    "x", batched, and mapped through ``ops.Unique()``, which replaces "x" with
    the output columns "x", "y" and "z".

    Args:
        array: Input values fed to the pipeline as one dataset row.
        res: Expected content of output column "x".
        idx: Expected content of output column "y".
        cnt: Expected content of output column "z".

    Raises:
        AssertionError: If any output column differs from its expected value,
            or if the pipeline yields no rows at all.
    """
    data = ds.NumpySlicesDataset([array], column_names="x")
    # NOTE(review): Unique is applied to batched data here; presumably the op
    # requires a batch op upstream - confirm against the MindSpore docs.
    data = data.batch(2)
    data = data.map(operations=ops.Unique(), input_columns=["x"], output_columns=["x", "y", "z"],
                    column_order=["x", "y", "z"])

    num_rows = 0
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(res, row["x"])
        np.testing.assert_array_equal(idx, row["y"])
        np.testing.assert_array_equal(cnt, row["z"])
        num_rows += 1

    # Without this guard an empty pipeline would skip every assertion above
    # and the test would pass vacuously.
    assert num_rows > 0, "Unique pipeline produced no rows"


def test_duplicate_basics():
    """Check Unique on integer input, float input, and all-identical input."""
    compare([0, 1, 2, 1, 2, 3], np.array([0, 1, 2, 3]),
            np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1]))
    compare([0.0, 1.0, 2.0, 1.0, 2.0, 3.0], np.array([0.0, 1.0, 2.0, 3.0]),
            np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1]))
    compare([1, 1, 1, 1, 1, 1], np.array([1]),
            np.array([0, 0, 0, 0, 0, 0]), np.array([6]))


if __name__ == "__main__":
    test_duplicate_basics()