• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include <memory>
17 #include <string>
18 #include "minddata/dataset/core/client.h"
19 #include "common/common.h"
20 #include "gtest/gtest.h"
21 #include "securec.h"
22 #include "minddata/dataset/core/tensor.h"
23 #include "minddata/dataset/core/cv_tensor.h"
24 #include "minddata/dataset/core/data_type.h"
25 
26 using namespace mindspore::dataset;
27 
28 namespace py = pybind11;
29 
30 class MindDataTestStringTensorDE : public UT::Common {
31  public:
32   MindDataTestStringTensorDE() = default;
33 
SetUp()34   void SetUp() override { GlobalInit(); }
35 };
36 
// Exercises scalar, 1-D and 2-D string tensors: creation, resulting shape, and element read-back
// through GetItemAt. Every factory/accessor returns a Status that is asserted immediately, so a
// failure stops the test at its real cause instead of dereferencing an unset tensor pointer.
TEST_F(MindDataTestStringTensorDE, Basics) {
  // Scalar string tensor: empty shape, one element.
  std::shared_ptr<Tensor> t;
  ASSERT_TRUE(Tensor::CreateScalar<std::string>("Hi", &t).IsOk());
  ASSERT_TRUE(t->shape() == TensorShape({}));
  std::string_view s;
  ASSERT_TRUE(t->GetItemAt(&s, {}).IsOk());
  ASSERT_EQ(s, "Hi");

  // 1-D tensor built from a vector without an explicit shape; the test expects shape {2}.
  std::shared_ptr<Tensor> t2;
  ASSERT_TRUE(Tensor::CreateFromVector(std::vector<std::string>{"Hi", "Bye"}, &t2).IsOk());
  ASSERT_TRUE(t2->shape() == TensorShape({2}));
  ASSERT_TRUE(t2->GetItemAt(&s, {0}).IsOk());
  ASSERT_EQ(s, "Hi");
  ASSERT_TRUE(t2->GetItemAt(&s, {1}).IsOk());
  ASSERT_EQ(s, "Bye");

  // 2-D tensor with an explicit shape; element (i, j) is expected to be strings[i * 3 + j].
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t3;
  ASSERT_TRUE(Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t3).IsOk());
  ASSERT_TRUE(t3->shape() == TensorShape({2, 3}));

  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      // Fresh view per element (the original shadowed the outer `s` here).
      std::string_view item;
      ASSERT_TRUE(t3->GetItemAt(&item, {i, j}).IsOk());
      ASSERT_EQ(item, strings[index]);
      ++index;
    }
  }
}
67 
// Checks the raw byte layout of a 2x3 string tensor, as this test expects it:
//  - the buffer begins with a table of (6 + 1) uint32_t entries;
//  - entry k is the byte position of element k measured from the start of the whole buffer,
//    and the last entry marks the end of the string payload;
//  - the strings follow the table, each followed by one extra byte (presumably a NUL
//    terminator — confirm against Tensor's string serialization).
// It also checks that GetItemAt hands back views that point into that payload, not copies.
TEST_F(MindDataTestStringTensorDE, Basics2) {
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t;
  ASSERT_TRUE(Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t).IsOk());

  // 6 element offsets + 1 end offset, 4 bytes each => 28 bytes.
  constexpr uint32_t kOffsetTableBytes = (6 + 1) * sizeof(uint32_t);

  // Per element: 4-byte offset + 1 separator byte; plus 20 characters of text and the 4-byte end offset.
  ASSERT_EQ(t->SizeInBytes(), 6 * 5 + 20 + 4);

  // Offsets relative to the start of the payload; stored values include the table size.
  std::vector<uint32_t> offsets = {0, 4, 9, 12, 18, 22, 26};
  const auto *raw = t->GetBuffer();
  for (size_t k = 0; k < offsets.size(); ++k) {
    // Copy instead of dereferencing a reinterpret_cast'ed byte pointer: avoids strict-aliasing
    // and alignment undefined behaviour.
    uint32_t stored = 0;
    ASSERT_EQ(memcpy_s(&stored, sizeof(stored), raw + k * sizeof(uint32_t), sizeof(uint32_t)), EOK);
    ASSERT_EQ(stored, offsets[k] + kOffsetTableBytes);
  }

  const char *buf = reinterpret_cast<const char *>(t->GetBuffer()) + kOffsetTableBytes;
  std::vector<uint32_t> starts = {0, 4, 9, 12, 18, 22};

  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      std::string_view s;
      ASSERT_TRUE(t->GetItemAt(&s, {i, j}).IsOk());
      // The view must alias the tensor's payload at the expected position...
      ASSERT_TRUE(s.data() == buf + starts[index]);
      // ...and carry the expected contents.
      ASSERT_EQ(s, strings[index]);
      ++index;
    }
  }
}
92 
// Same layout check as a non-empty string tensor, but with empty strings mixed in. Each empty
// element still gets its own offset entry and a one-byte slot in the payload, so consecutive
// empty elements have offsets exactly one apart (9, 10 below).
TEST_F(MindDataTestStringTensorDE, Empty) {
  std::vector<std::string> strings{"abc", "defg", "", "", "123", ""};
  std::shared_ptr<Tensor> t;
  ASSERT_TRUE(Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t).IsOk());

  // 6 element offsets + 1 end offset, 4 bytes each => 28 bytes.
  constexpr uint32_t kOffsetTableBytes = (6 + 1) * sizeof(uint32_t);

  // Payload layout ('_' = separator byte, presumably NUL):
  //  abc_defg___123__
  //  0123456789012345
  ASSERT_EQ(t->SizeInBytes(), 6 * 5 + 10 + 4);

  std::vector<uint32_t> offsets = {0, 4, 9, 10, 11, 15, 16};
  const auto *raw = t->GetBuffer();
  for (size_t k = 0; k < offsets.size(); ++k) {
    // Copy instead of dereferencing a reinterpret_cast'ed byte pointer (strict aliasing / alignment).
    uint32_t stored = 0;
    ASSERT_EQ(memcpy_s(&stored, sizeof(stored), raw + k * sizeof(uint32_t), sizeof(uint32_t)), EOK);
    ASSERT_EQ(stored, offsets[k] + kOffsetTableBytes);
  }

  const char *buf = reinterpret_cast<const char *>(t->GetBuffer()) + kOffsetTableBytes;
  std::vector<uint32_t> starts = {0, 4, 9, 10, 11, 15};

  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      std::string_view s;
      ASSERT_TRUE(t->GetItemAt(&s, {i, j}).IsOk());
      ASSERT_TRUE(s.data() == buf + starts[index]);
      ASSERT_EQ(s, strings[index]);
      ++index;
    }
  }
}
119 
// Distinguishes "one empty string" from "no elements": a scalar holding "" is expected to report
// data, while a string tensor of shape {0} is expected not to.
TEST_F(MindDataTestStringTensorDE, EmptyData) {
  std::shared_ptr<Tensor> t;
  ASSERT_TRUE(Tensor::CreateScalar<std::string>("", &t).IsOk());
  // An empty string is still one element.
  ASSERT_TRUE(t->HasData());

  std::shared_ptr<Tensor> t1;
  ASSERT_TRUE(Tensor::CreateEmpty(TensorShape({0}), DataType(DataType::DE_STRING), &t1).IsOk());
  ASSERT_FALSE(t1->HasData());
}
130 
// Overwrites individual elements of a 2x3 string tensor with SetItemAt, then reads every element
// back and compares it against a mirror vector updated in lock-step. Index (r, c) maps to the flat
// position r * 3 + c. Every replacement has the same length as the string it replaces (whether
// different lengths are supported is not exercised here).
TEST_F(MindDataTestStringTensorDE, SetItem) {
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t3;
  ASSERT_TRUE(Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t3).IsOk());
  ASSERT_TRUE(t3->shape() == TensorShape({2, 3}));

  // The Status of each SetItemAt is checked, so a rejected write is reported as such.
  ASSERT_TRUE(t3->SetItemAt({0, 1}, std::string{"xyzz"}).IsOk());
  strings[1] = "xyzz";

  ASSERT_TRUE(t3->SetItemAt({0, 2}, std::string{"07"}).IsOk());
  strings[2] = "07";

  ASSERT_TRUE(t3->SetItemAt({1, 2}, std::string{"987"}).IsOk());
  strings[5] = "987";

  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      std::string_view s;
      ASSERT_TRUE(t3->GetItemAt(&s, {i, j}).IsOk());
      ASSERT_EQ(s, strings[index]);
      ++index;
    }
  }
}
156 
// Walks a 2x3 string tensor with its string_view iterator, first one element at a time and then
// with a stride of 2. Each pass checks the visited values and also that the pass visited exactly
// the expected number of elements. Without that check, an iterator whose begin() == end() would
// pass vacuously, and an overrunning iterator would index past the end of `strings`.
TEST_F(MindDataTestStringTensorDE, Iterator) {
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t;
  ASSERT_TRUE(Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t).IsOk());

  // Pass 1: unit stride, every element in order.
  size_t index = 0;
  auto itr = t->begin<std::string_view>();
  for (; itr != t->end<std::string_view>(); ++itr) {
    ASSERT_LT(index, strings.size());  // guard against the iterator running past the data
    ASSERT_EQ(*itr, strings[index]);
    ++index;
  }
  ASSERT_EQ(index, strings.size());

  // Pass 2: stride 2. With 6 elements the iterator lands exactly on end() after visiting 0, 2, 4.
  index = 0;
  for (itr = t->begin<std::string_view>(); itr != t->end<std::string_view>(); itr += 2) {
    ASSERT_LT(index, strings.size());
    ASSERT_EQ(*itr, strings[index]);
    index += 2;
  }
  ASSERT_EQ(index, strings.size());
}
174