1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <memory>
17 #include <string>
18 #include "minddata/dataset/core/client.h"
19 #include "common/common.h"
20 #include "gtest/gtest.h"
21 #include "securec.h"
22 #include "minddata/dataset/core/tensor.h"
23 #include "minddata/dataset/core/cv_tensor.h"
24 #include "minddata/dataset/core/data_type.h"
25 #include "minddata/dataset/kernels/data/mask_op.h"
26 #include "minddata/dataset/kernels/data/data_utils.h"
27
28 using namespace mindspore::dataset;
29
30 namespace py = pybind11;
31
32 class MindDataTestMaskOp : public UT::Common {
33 public:
MindDataTestMaskOp()34 MindDataTestMaskOp() {}
35
SetUp()36 void SetUp() { GlobalInit(); }
37 };
38
// Verifies MaskOp for every RelationalOp.
// Input: the 1-D uint32 tensor {1, 2, 3, 4, 5, 6}.
// Each case compares every element against the scalar 3 and requests a DE_UINT16 output.
// The test checks both that Compute succeeds and that the produced mask holds the expected values.
TEST_F(MindDataTestMaskOp, Basics) {
  std::shared_ptr<Tensor> t;
  ASSERT_TRUE(Tensor::CreateFromVector(std::vector<uint32_t>({1, 2, 3, 4, 5, 6}), &t).IsOk());
  std::shared_ptr<Tensor> v;
  ASSERT_TRUE(Tensor::CreateFromVector(std::vector<uint32_t>({3}), TensorShape::CreateScalar(), &v).IsOk());

  // One entry per relational operator: the element-wise result of (t[i] <op> 3), encoded as 0/1.
  struct MaskCase {
    const char *name;
    RelationalOp op;
    std::vector<uint16_t> expected;
  };
  const std::vector<MaskCase> cases = {
    {"kEqual", RelationalOp::kEqual, {0, 0, 1, 0, 0, 0}},
    {"kNotEqual", RelationalOp::kNotEqual, {1, 1, 0, 1, 1, 1}},
    {"kLessEqual", RelationalOp::kLessEqual, {1, 1, 1, 0, 0, 0}},
    {"kLess", RelationalOp::kLess, {1, 1, 0, 0, 0, 0}},
    {"kGreaterEqual", RelationalOp::kGreaterEqual, {0, 0, 1, 1, 1, 1}},
    {"kGreater", RelationalOp::kGreater, {0, 0, 0, 1, 1, 1}},
  };

  for (const auto &c : cases) {
    // Label any failure with the operator being exercised.
    SCOPED_TRACE(c.name);

    std::shared_ptr<MaskOp> op = std::make_shared<MaskOp>(c.op, v, DataType(DataType::DE_UINT16));
    std::shared_ptr<Tensor> out;
    ASSERT_TRUE(op->Compute(t, &out).IsOk());
    ASSERT_NE(out, nullptr);

    // A uint16 vector yields a DE_UINT16 tensor of shape {6}.
    // Tensor equality therefore checks shape, dtype and contents together.
    std::shared_ptr<Tensor> expected;
    ASSERT_TRUE(Tensor::CreateFromVector(c.expected, &expected).IsOk());
    ASSERT_TRUE(*out == *expected);
  }
}
63