1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common.h" 17 #include "minddata/dataset/kernels/data/concatenate_op.h" 18 #include "utils/log_adapter.h" 19 20 using namespace mindspore::dataset; 21 using mindspore::LogStream; 22 using mindspore::ExceptionType::NoExceptionType; 23 using mindspore::MsLogLevel::INFO; 24 25 class MindDataTestConcatenateOp : public UT::Common { 26 protected: 27 MindDataTestConcatenateOp() {} 28 }; 29 30 TEST_F(MindDataTestConcatenateOp, TestOp) { 31 MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp-SingleRowinput."; 32 std::vector<uint64_t> labels = {1, 1, 2}; 33 std::shared_ptr<Tensor> input; 34 Tensor::CreateFromVector(labels, &input); 35 36 std::vector<uint64_t> append_labels = {4, 4, 4}; 37 std::shared_ptr<Tensor> append; 38 Tensor::CreateFromVector(append_labels, &append); 39 40 std::shared_ptr<Tensor> output; 41 std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append)); 42 TensorRow in; 43 in.push_back(input); 44 TensorRow out_row; 45 Status s = op->Compute(in, &out_row); 46 std::vector<uint64_t> out = {1, 1, 2, 4, 4, 4}; 47 48 std::shared_ptr<Tensor> expected; 49 Tensor::CreateFromVector(out, &expected); 50 51 output = out_row[0]; 52 EXPECT_TRUE(s.IsOk()); 53 ASSERT_TRUE(output->shape() == expected->shape()); 54 ASSERT_TRUE(output->type() == expected->type()); 55 MS_LOG(DEBUG) << *output << std::endl; 56 MS_LOG(DEBUG) << *expected << std::endl; 57 ASSERT_TRUE(*output == *expected); 58 } 59 60 TEST_F(MindDataTestConcatenateOp, TestOp2) { 61 MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp2-MultiInput."; 62 std::vector<uint64_t> labels = {1, 12, 2}; 63 std::shared_ptr<Tensor> row_1; 64 Tensor::CreateFromVector(labels, &row_1); 65 66 std::shared_ptr<Tensor> row_2; 67 Tensor::CreateFromVector(labels, &row_2); 68 69 std::vector<uint64_t> append_labels = {4, 4, 4}; 70 std::shared_ptr<Tensor> append; 71 Tensor::CreateFromVector(append_labels, &append); 72 73 TensorRow tensor_list; 74 tensor_list.push_back(row_1); 75 tensor_list.push_back(row_2); 76 77 std::shared_ptr<Tensor> output; 78 std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append)); 79 80 TensorRow out_row; 81 Status s = op->Compute(tensor_list, &out_row); 82 std::vector<uint64_t> out = {1, 12, 2, 1, 12, 2, 4, 4, 4}; 83 84 std::shared_ptr<Tensor> expected; 85 Tensor::CreateFromVector(out, &expected); 86 87 output = out_row[0]; 88 EXPECT_TRUE(s.IsOk()); 89 ASSERT_TRUE(output->shape() == expected->shape()); 90 ASSERT_TRUE(output->type() == expected->type()); 91 MS_LOG(DEBUG) << *output << std::endl; 92 MS_LOG(DEBUG) << *expected << std::endl; 93 ASSERT_TRUE(*output == *expected); 94 } 95 96 TEST_F(MindDataTestConcatenateOp, TestOp3) { 97 MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp3-Strings."; 98 std::vector<std::string> labels = {"hello", "bye"}; 99 std::shared_ptr<Tensor> row_1; 100 Tensor::CreateFromVector(labels, &row_1); 101 102 std::vector<std::string> append_labels = {"1", "2", "3"}; 103 std::shared_ptr<Tensor> append; 104 Tensor::CreateFromVector(append_labels, &append); 105 106 TensorRow tensor_list; 107 tensor_list.push_back(row_1); 108 109 std::shared_ptr<Tensor> output; 110 std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append)); 111 112 TensorRow out_row; 113 Status s = op->Compute(tensor_list, &out_row); 114 std::vector<std::string> out = {"hello", "bye", "1", "2", "3"}; 115 116 std::shared_ptr<Tensor> expected; 117 Tensor::CreateFromVector(out, &expected); 118 119 output = out_row[0]; 120 EXPECT_TRUE(s.IsOk()); 121 ASSERT_TRUE(output->shape() == expected->shape()); 122 ASSERT_TRUE(output->type() == expected->type()); 123 MS_LOG(DEBUG) << *output << std::endl; 124 MS_LOG(DEBUG) << *expected << std::endl; 125 ASSERT_TRUE(*output == *expected); 126 } 127