• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/common.h"
17 #include "minddata/dataset/kernels/data/concatenate_op.h"
18 #include "utils/log_adapter.h"
19 
20 using namespace mindspore::dataset;
21 using mindspore::LogStream;
22 using mindspore::ExceptionType::NoExceptionType;
23 using mindspore::MsLogLevel::INFO;
24 
25 class MindDataTestConcatenateOp : public UT::Common {
26  protected:
27   MindDataTestConcatenateOp() {}
28 };
29 
30 TEST_F(MindDataTestConcatenateOp, TestOp) {
31   MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp-SingleRowinput.";
32   std::vector<uint64_t> labels = {1, 1, 2};
33   std::shared_ptr<Tensor> input;
34   Tensor::CreateFromVector(labels, &input);
35 
36   std::vector<uint64_t> append_labels = {4, 4, 4};
37   std::shared_ptr<Tensor> append;
38   Tensor::CreateFromVector(append_labels, &append);
39 
40   std::shared_ptr<Tensor> output;
41   std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
42   TensorRow in;
43   in.push_back(input);
44   TensorRow out_row;
45   Status s = op->Compute(in, &out_row);
46   std::vector<uint64_t> out = {1, 1, 2, 4, 4, 4};
47 
48   std::shared_ptr<Tensor> expected;
49   Tensor::CreateFromVector(out, &expected);
50 
51   output = out_row[0];
52   EXPECT_TRUE(s.IsOk());
53   ASSERT_TRUE(output->shape() == expected->shape());
54   ASSERT_TRUE(output->type() == expected->type());
55   MS_LOG(DEBUG) << *output << std::endl;
56   MS_LOG(DEBUG) << *expected << std::endl;
57   ASSERT_TRUE(*output == *expected);
58 }
59 
60 TEST_F(MindDataTestConcatenateOp, TestOp2) {
61   MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp2-MultiInput.";
62   std::vector<uint64_t> labels = {1, 12, 2};
63   std::shared_ptr<Tensor> row_1;
64   Tensor::CreateFromVector(labels, &row_1);
65 
66   std::shared_ptr<Tensor> row_2;
67   Tensor::CreateFromVector(labels, &row_2);
68 
69   std::vector<uint64_t> append_labels = {4, 4, 4};
70   std::shared_ptr<Tensor> append;
71   Tensor::CreateFromVector(append_labels, &append);
72 
73   TensorRow tensor_list;
74   tensor_list.push_back(row_1);
75   tensor_list.push_back(row_2);
76 
77   std::shared_ptr<Tensor> output;
78   std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
79 
80   TensorRow out_row;
81   Status s = op->Compute(tensor_list, &out_row);
82   std::vector<uint64_t> out = {1, 12, 2, 1, 12, 2, 4, 4, 4};
83 
84   std::shared_ptr<Tensor> expected;
85   Tensor::CreateFromVector(out, &expected);
86 
87   output = out_row[0];
88   EXPECT_TRUE(s.IsOk());
89   ASSERT_TRUE(output->shape() == expected->shape());
90   ASSERT_TRUE(output->type() == expected->type());
91   MS_LOG(DEBUG) << *output << std::endl;
92   MS_LOG(DEBUG) << *expected << std::endl;
93   ASSERT_TRUE(*output == *expected);
94 }
95 
96 TEST_F(MindDataTestConcatenateOp, TestOp3) {
97   MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp3-Strings.";
98   std::vector<std::string> labels = {"hello", "bye"};
99   std::shared_ptr<Tensor> row_1;
100   Tensor::CreateFromVector(labels, &row_1);
101 
102   std::vector<std::string> append_labels = {"1", "2", "3"};
103   std::shared_ptr<Tensor> append;
104   Tensor::CreateFromVector(append_labels, &append);
105 
106   TensorRow tensor_list;
107   tensor_list.push_back(row_1);
108 
109   std::shared_ptr<Tensor> output;
110   std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
111 
112   TensorRow out_row;
113   Status s = op->Compute(tensor_list, &out_row);
114   std::vector<std::string> out = {"hello", "bye", "1", "2", "3"};
115 
116   std::shared_ptr<Tensor> expected;
117   Tensor::CreateFromVector(out, &expected);
118 
119   output = out_row[0];
120   EXPECT_TRUE(s.IsOk());
121   ASSERT_TRUE(output->shape() == expected->shape());
122   ASSERT_TRUE(output->type() == expected->type());
123   MS_LOG(DEBUG) << *output << std::endl;
124   MS_LOG(DEBUG) << *expected << std::endl;
125   ASSERT_TRUE(*output == *expected);
126 }
127