1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/tensor_coding.h"
17
18 #include <vector>
19
20 #include "tensorflow/core/lib/core/coding.h"
21 #include "tensorflow/core/lib/core/stringpiece.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/protobuf.h"
24
25 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
26 #include "strings/cord_varint.h"
27 #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD)
28
29 namespace tensorflow {
30 namespace port {
31
AssignRefCounted(StringPiece src,core::RefCounted * obj,string * out)32 void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) {
33 out->assign(src.data(), src.size());
34 }
35
EncodeStringList(const string * strings,int64 n,string * out)36 void EncodeStringList(const string* strings, int64 n, string* out) {
37 out->clear();
38 for (int i = 0; i < n; ++i) {
39 core::PutVarint32(out, strings[i].size());
40 }
41 for (int i = 0; i < n; ++i) {
42 out->append(strings[i]);
43 }
44 }
45
DecodeStringList(const string & src,string * strings,int64 n)46 bool DecodeStringList(const string& src, string* strings, int64 n) {
47 std::vector<uint32> sizes(n);
48 StringPiece reader(src);
49 int64 tot = 0;
50 for (auto& v : sizes) {
51 if (!core::GetVarint32(&reader, &v)) return false;
52 tot += v;
53 }
54 if (tot != static_cast<int64>(reader.size())) {
55 return false;
56 }
57
58 string* data = strings;
59 for (int64 i = 0; i < n; ++i, ++data) {
60 auto size = sizes[i];
61 if (size > reader.size()) {
62 return false;
63 }
64 data->assign(reader.data(), size);
65 reader.remove_prefix(size);
66 }
67
68 return true;
69 }
70
// Replaces the contents of `*s` with the `bytes` bytes starting at `base`.
void CopyFromArray(string* s, const char* base, size_t bytes) {
  string& dst = *s;
  dst.assign(base, base + bytes);
}
74
75 class StringListEncoderImpl : public StringListEncoder {
76 public:
StringListEncoderImpl(string * out)77 explicit StringListEncoderImpl(string* out) : out_(out) {}
78 ~StringListEncoderImpl() override = default;
79
Append(const protobuf::MessageLite & m)80 void Append(const protobuf::MessageLite& m) override {
81 core::PutVarint32(out_, m.ByteSizeLong());
82 tensorflow::string serialized_message;
83 m.AppendToString(&serialized_message);
84 strings::StrAppend(&rest_, serialized_message);
85 }
86
Append(const string & s)87 void Append(const string& s) override {
88 core::PutVarint32(out_, s.length());
89 strings::StrAppend(&rest_, s);
90 }
91
Finalize()92 void Finalize() override { strings::StrAppend(out_, rest_); }
93
94 private:
95 string* out_;
96 string rest_;
97 };
98
99 class StringListDecoderImpl : public StringListDecoder {
100 public:
StringListDecoderImpl(const string & in)101 explicit StringListDecoderImpl(const string& in) : reader_(in) {}
102 ~StringListDecoderImpl() override = default;
103
ReadSizes(std::vector<uint32> * sizes)104 bool ReadSizes(std::vector<uint32>* sizes) override {
105 int64 total = 0;
106 for (auto& size : *sizes) {
107 if (!core::GetVarint32(&reader_, &size)) return false;
108 total += size;
109 }
110 if (total != static_cast<int64>(reader_.size())) {
111 return false;
112 }
113 return true;
114 }
115
Data(uint32 size)116 const char* Data(uint32 size) override {
117 const char* data = reader_.data();
118 reader_.remove_prefix(size);
119 return data;
120 }
121
122 private:
123 StringPiece reader_;
124 };
125
NewStringListEncoder(string * out)126 std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) {
127 return std::unique_ptr<StringListEncoder>(new StringListEncoderImpl(out));
128 }
129
NewStringListDecoder(const string & in)130 std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) {
131 return std::unique_ptr<StringListDecoder>(new StringListDecoderImpl(in));
132 }
133
134 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
AssignRefCounted(StringPiece src,core::RefCounted * obj,Cord * out)135 void AssignRefCounted(StringPiece src, core::RefCounted* obj, Cord* out) {
136 obj->Ref();
137 out->Clear();
138 // Defines a lambda to unref "obj" when Cord deletes this piece of
139 // memory. +[] converts the lambda to a C style function pointer.
140 auto cleanup = +[](absl::string_view donotcare, void* obj) {
141 reinterpret_cast<core::RefCounted*>(obj)->Unref();
142 };
143 out->AppendExternalMemory(absl::string_view(src.data(), src.size()), obj,
144 cleanup);
145 }
146
EncodeStringList(const string * strings,int64 n,Cord * out)147 void EncodeStringList(const string* strings, int64 n, Cord* out) {
148 out->Clear();
149 for (int i = 0; i < n; ++i) {
150 ::strings::CordAppendVarint(strings[i].size(), out);
151 }
152 for (int i = 0; i < n; ++i) {
153 out->Append(strings[i]);
154 }
155 }
156
DecodeStringList(const Cord & src,string * strings,int64 n)157 bool DecodeStringList(const Cord& src, string* strings, int64 n) {
158 std::vector<uint32> sizes(n);
159 CordReader reader(src);
160 int64 tot = 0;
161 for (auto& v : sizes) {
162 if (!::strings::CordReaderReadVarint(&reader, &v)) return false;
163 tot += v;
164 }
165 if (tot != reader.Available()) {
166 return false;
167 }
168 string* data = strings;
169 for (int i = 0; i < n; ++i, ++data) {
170 auto size = sizes[i];
171 if (size > reader.Available()) {
172 return false;
173 }
174 gtl::STLStringResizeUninitialized(data, size);
175 reader.ReadN(size, gtl::string_as_array(data));
176 }
177 return true;
178 }
179
CopyFromArray(Cord * c,const char * base,size_t bytes)180 void CopyFromArray(Cord* c, const char* base, size_t bytes) {
181 c->CopyFrom(base, bytes);
182 }
183
184 class CordStringListEncoderImpl : public StringListEncoder {
185 public:
CordStringListEncoderImpl(Cord * out)186 explicit CordStringListEncoderImpl(Cord* out) : out_(out) {}
187 ~CordStringListEncoderImpl() override = default;
188
Append(const protobuf::MessageLite & m)189 void Append(const protobuf::MessageLite& m) override {
190 ::strings::CordAppendVarint(m.ByteSizeLong(), out_);
191 m.AppendToString(&rest_);
192 }
193
Append(const string & s)194 void Append(const string& s) override {
195 ::strings::CordAppendVarint(s.length(), out_);
196 rest_.append(s.data(), s.size());
197 }
198
Finalize()199 void Finalize() override { out_->Append(rest_); }
200
201 private:
202 Cord* out_;
203 string rest_;
204 };
205
206 class CordStringListDecoderImpl : public StringListDecoder {
207 public:
CordStringListDecoderImpl(const Cord & in)208 explicit CordStringListDecoderImpl(const Cord& in) : reader_(in) {}
209 ~CordStringListDecoderImpl() override = default;
210
ReadSizes(std::vector<uint32> * sizes)211 bool ReadSizes(std::vector<uint32>* sizes) override {
212 int64 total = 0;
213 for (auto& size : *sizes) {
214 if (!::strings::CordReaderReadVarint(&reader_, &size)) return false;
215 total += size;
216 }
217 if (total != static_cast<int64>(reader_.Available())) {
218 return false;
219 }
220 return true;
221 }
222
Data(uint32 size)223 const char* Data(uint32 size) override {
224 tmp_.resize(size);
225 reader_.ReadN(size, tmp_.data());
226 return tmp_.data();
227 }
228
229 private:
230 CordReader reader_;
231 std::vector<char> tmp_;
232 };
233
NewStringListEncoder(Cord * out)234 std::unique_ptr<StringListEncoder> NewStringListEncoder(Cord* out) {
235 return std::unique_ptr<StringListEncoder>(new CordStringListEncoderImpl(out));
236 }
237
NewStringListDecoder(const Cord & in)238 std::unique_ptr<StringListDecoder> NewStringListDecoder(const Cord& in) {
239 return std::unique_ptr<StringListDecoder>(new CordStringListDecoderImpl(in));
240 }
241
242 #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD)
243
244 } // namespace port
245 } // namespace tensorflow
246