• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "src/common/string_utils.h"
#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <string>
#include <utility>
#include <vector>
23 
24 namespace mindspore {
25 namespace lite {
ParseTensorBuffer(Tensor * tensor)26 std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) {
27   if (tensor == nullptr) {
28     MS_LOG(ERROR) << "tensor is nullptr.";
29     return std::vector<StringPack>{};
30   }
31   if (tensor->data() == nullptr) {
32     MS_LOG(ERROR) << "Tensor data is null, cannot be parsed";
33     return std::vector<StringPack>{};
34   }
35   return ParseStringBuffer(tensor->MutableData());
36 }
37 
ParseStringBuffer(const void * data)38 std::vector<StringPack> ParseStringBuffer(const void *data) {
39   std::vector<StringPack> buffer;
40   if (data == nullptr) {
41     MS_LOG(ERROR) << "data is nullptr";
42     return buffer;
43   }
44   const auto *offset = reinterpret_cast<const int32_t *>(data);
45   int32_t num = *offset;
46   for (int i = 0; i < num; i++) {
47     offset += 1;
48     buffer.push_back(StringPack{(*(offset + 1)) - (*offset), reinterpret_cast<const char *>(data) + (*offset)});
49   }
50   return buffer;
51 }
52 
WriteStringsToTensor(Tensor * tensor,const std::vector<StringPack> & string_buffer)53 int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer) {
54   if (tensor == nullptr) {
55     MS_LOG(ERROR) << "tensor is nullptr.";
56     return RET_ERROR;
57   }
58   size_t num = string_buffer.size();
59   std::vector<int32_t> offset(num + 1);
60   const size_t extra_offset_num = 2;
61   offset[0] = static_cast<int32_t>(sizeof(int32_t) * (num + extra_offset_num));
62   for (size_t i = 0; i < num; i++) {
63     offset[i + 1] = offset[i] + string_buffer[i].len;
64   }
65   std::vector<int> shape = {offset[num]};
66   tensor->set_shape(shape);
67   tensor->set_data_type(kObjectTypeString);
68   tensor->FreeData();
69   void *data = tensor->MutableData();
70   if (data == nullptr) {
71     return RET_ERROR;
72   }
73 
74   auto *string_info = reinterpret_cast<int32_t *>(data);
75   char *string_data = reinterpret_cast<char *>(data);
76 
77   string_info[0] = static_cast<int32_t>(num);
78   for (size_t i = 0; i <= num; i++) {
79     string_info[i + 1] = offset[i];
80   }
81   for (size_t i = 0; i < num; i++) {
82     memcpy(string_data + offset[i], string_buffer[i].data, string_buffer[i].len);
83   }
84   return RET_OK;
85 }
86 
WriteSeperatedStringsToTensor(Tensor * tensor,const std::vector<std::vector<StringPack>> & string_buffer)87 int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer) {
88   if (tensor == nullptr) {
89     MS_LOG(ERROR) << "tensor is nullptr.";
90     return RET_ERROR;
91   }
92   size_t num = string_buffer.size();
93   std::vector<int32_t> offset(num + 1);
94   const size_t extra_offset_num = 2;
95   offset[0] = static_cast<int32_t>(sizeof(int32_t) * (num + extra_offset_num));
96   std::vector<int> len(num);
97   for (size_t i = 0; i < num; i++) {
98     len[i] = 0;
99     for (int j = 0; j < static_cast<int>(string_buffer[i].size()); j++) {
100       len[i] += string_buffer[i][j].len;
101     }
102     offset[i + 1] = offset[i] + len[i];
103   }
104 
105   std::vector<int> shape = {offset[num]};
106   tensor->set_shape(shape);
107   tensor->FreeData();
108   void *data = tensor->MutableData();
109   if (data == nullptr) {
110     return RET_ERROR;
111   }
112 
113   auto *string_info = reinterpret_cast<int32_t *>(data);
114   auto *string_data = reinterpret_cast<char *>(data);
115 
116   string_info[0] = static_cast<int32_t>(num);
117   for (size_t i = 0; i <= num; i++) {
118     string_info[i + 1] = offset[i];
119   }
120   for (size_t i = 0; i < num; i++) {
121     auto *dst = string_data + offset[i];
122     for (auto string_part : string_buffer[i]) {
123       memcpy(dst, string_part.data, string_part.len);
124       dst += string_part.len;
125     }
126   }
127   return RET_OK;
128 }
129 
GetStringCount(const void * data)130 int GetStringCount(const void *data) { return *(static_cast<const int32_t *>(data)); }
131 
GetStringCount(Tensor * tensor)132 int GetStringCount(Tensor *tensor) {
133   if (tensor == nullptr) {
134     MS_LOG(ERROR) << "tensor is nullptr.";
135     return RET_ERROR;
136   }
137   return GetStringCount(tensor->MutableData());
138 }
139 
namespace {
// Some primes between 2^63 and 2^64
constexpr uint64_t k0 = 0xc3a5c85c97cb3127ULL;
constexpr uint64_t k1 = 0xb492b66fbe98f273ULL;
constexpr uint64_t k2 = 0x9ae16a3b2f90404fULL;

// Loads 8 bytes starting at `p` into a uint64_t (host byte order). memcpy is
// used so that `p` does not have to be aligned.
uint64_t Fetch64Bit(const char *p) {
  uint64_t word = 0;
  memcpy(&word, p, sizeof(word));
  return word;
}

// Loads 4 bytes starting at `p` into a uint32_t (host byte order).
uint32_t Fetch32Bit(const char *p) {
  uint32_t word = 0;
  memcpy(&word, p, sizeof(word));
  return word;
}

// Rotates `value` right by `shift` bits. A shift of 0 is handled separately
// because shifting a 64-bit value by 64 is undefined.
uint64_t Rotate64(uint64_t value, int shift) {
  if (shift == 0) {
    return value;
  }
  const auto right = static_cast<unsigned int>(shift);
  const auto left = static_cast<unsigned int>(64 - shift);
  return (value >> right) | (value << left);
}

// Folds the upper bits of `value` into its lower bits.
uint64_t ShiftMix(uint64_t value) { return value ^ (value >> 47); }

// Mixes two 64-bit values into one using the multiplier `multiple`.
uint64_t HashLen16(uint64_t u, uint64_t v, uint64_t multiple) {
  const uint64_t first = ShiftMix((u ^ v) * multiple);
  const uint64_t second = ShiftMix((v ^ first) * multiple);
  return second * multiple;
}

// Hashes `len` bytes at `s`; callers pass 0 <= len <= 16.
uint64_t HashStringLen0to16(const char *s, size_t len) {
  if (len == 0) {
    return k2;
  }
  if (len < 4) {
    // 1..3 bytes: combine the first, middle and last byte.
    const uint8_t head = s[0];
    const uint8_t middle = s[len >> 1];
    const uint8_t tail = s[len - 1];
    const uint32_t y = static_cast<uint32_t>(head) + (static_cast<uint32_t>(middle) << 8);
    const uint32_t z = len + (static_cast<uint32_t>(tail) << 2);
    return ShiftMix((y * k2) ^ (z * k0)) * k2;
  }
  const uint64_t mul = k2 + len * 2;
  if (len < 8) {
    // 4..7 bytes: two (possibly overlapping) 32-bit reads.
    const uint64_t head = Fetch32Bit(s);
    return HashLen16(len + (head << 3), Fetch32Bit(s + len - 4), mul);
  }
  // 8..16 bytes: two (possibly overlapping) 64-bit reads.
  const uint64_t head = Fetch64Bit(s) + k2;
  const uint64_t tail = Fetch64Bit(s + len - 8);
  const uint64_t u = Rotate64(tail, 37) * mul + head;
  const uint64_t v = (Rotate64(head, 25) + tail) * mul;
  return HashLen16(u, v, mul);
}

// Hashes `len` bytes at `s`; callers pass 17 <= len <= 32. Reads the first 16
// and the last 16 bytes.
uint64_t HashStringLen17to32(const char *s, size_t len) {
  const uint64_t mul = k2 + len * 2;
  const uint64_t head0 = Fetch64Bit(s) * k1;
  const uint64_t head1 = Fetch64Bit(s + 8);
  const uint64_t tail0 = Fetch64Bit(s + len - 8) * mul;
  const uint64_t tail1 = Fetch64Bit(s + len - 16) * k2;
  const uint64_t u = Rotate64(head0 + head1, 43) + Rotate64(tail0, 30) + tail1;
  const uint64_t v = head0 + Rotate64(head1 + k2, 18) + tail0;
  return HashLen16(u, v, mul);
}

// Hashes `len` bytes at `s`; callers pass 33 <= len <= 64. Reads the first 32
// and the last 32 bytes.
uint64_t HashStringLen33to64(const char *s, size_t len) {
  const uint64_t mul = k2 + len * 2;
  const uint64_t head0 = Fetch64Bit(s) * k2;
  const uint64_t head1 = Fetch64Bit(s + 8);
  const uint64_t tail0 = Fetch64Bit(s + len - 8) * mul;
  const uint64_t tail1 = Fetch64Bit(s + len - 16) * k2;
  const uint64_t y = Rotate64(head0 + head1, 43) + Rotate64(tail0, 30) + tail1;
  const uint64_t z = HashLen16(y, head0 + Rotate64(head1 + k2, 18) + tail0, mul);
  const uint64_t head2 = Fetch64Bit(s + 16) * mul;
  const uint64_t head3 = Fetch64Bit(s + 24);
  const uint64_t tail3 = (y + Fetch64Bit(s + len - 32)) * mul;
  const uint64_t tail2 = (z + Fetch64Bit(s + len - 24)) * mul;
  const uint64_t u = Rotate64(head2 + head3, 43) + Rotate64(tail3, 30) + tail2;
  const uint64_t v = head2 + Rotate64(head3 + head0, 18) + tail3;
  return HashLen16(u, v, mul);
}

// Mixes the 32 bytes at `s` with the seeds `a` and `b` and returns the two
// resulting 64-bit values.
std::pair<uint64_t, uint64_t> HashLen32WithSeeds(const char *s, uint64_t a, uint64_t b) {
  const uint64_t word0 = Fetch64Bit(s);
  const uint64_t word1 = Fetch64Bit(s + 8);
  const uint64_t word2 = Fetch64Bit(s + 16);
  const uint64_t word3 = Fetch64Bit(s + 24);
  a += word0;
  b = Rotate64(b + a + word3, 21);
  const uint64_t saved = a;
  a += word1;
  a += word2;
  b += Rotate64(a, 44);
  return std::make_pair(a + word3, b + saved);
}
}  // namespace
234 
StringHash64(const char * s,size_t len)235 uint64_t StringHash64(const char *s, size_t len) {
236   if (s == nullptr) {
237     return 0;
238   }
239   const uint64_t seed_value = 81;
240   if (len <= 16) {
241     return HashStringLen0to16(s, len);
242   } else if (len <= 32) {
243     return HashStringLen17to32(s, len);
244   } else if (len <= 64) {
245     return HashStringLen33to64(s, len);
246   }
247 
248   uint64_t x = seed_value;
249   uint64_t y = seed_value * k1 + 113;
250   uint64_t tmp = y * k2 + 113;
251   uint64_t z = (tmp ^ (tmp >> 47)) * k2;
252   std::pair<uint64_t, uint64_t> v = std::make_pair(0, 0);
253   std::pair<uint64_t, uint64_t> w = std::make_pair(0, 0);
254   x = x * k2 + Fetch64Bit(s);
255 
256   const char *end = s + ((len - 1) / 64) * 64;
257   const char *last64 = end + ((len - 1) & 63) - 63;
258   MS_ASSERT(s + len - 64 == last64);
259   do {
260     x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * k1;
261     y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * k1;
262     x ^= w.second;
263     y += v.first + Fetch64Bit(s + 40);
264     z = Rotate64(z + w.first, 33) * k1;
265     v = HashLen32WithSeeds(s, v.second * k1, x + w.first);
266     w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16));
267     std::swap(z, x);
268     s += 64;
269   } while (s != end);
270   uint64_t mul = k1 + ((z & 0xff) << 1);
271   s = last64;
272   w.first += ((len - 1) & 63);
273   v.first += w.first;
274   w.first += v.first;
275   x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * mul;
276   y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * mul;
277   x ^= w.second * 9;
278   y += v.first * 9 + Fetch64Bit(s + 40);
279   z = Rotate64(z + w.first, 33) * mul;
280   v = HashLen32WithSeeds(s, v.second * mul, x + w.first);
281   w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16));
282   std::swap(z, x);
283   return HashLen16(HashLen16(v.first, w.first, mul) + ShiftMix(y) * k0 + z, HashLen16(v.second, w.second, mul) + x,
284                    mul);
285 }
286 }  // namespace lite
287 }  // namespace mindspore
288