• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/tensor_coding.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/core/lib/core/coding.h"
21 #include "tensorflow/core/lib/core/stringpiece.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/protobuf.h"
24 
25 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
26 #include "strings/cord_varint.h"
27 #endif  // defined(TENSORFLOW_PROTOBUF_USES_CORD)
28 
29 namespace tensorflow {
30 namespace port {
31 
AssignRefCounted(StringPiece src,core::RefCounted * obj,string * out)32 void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) {
33   out->assign(src.data(), src.size());
34 }
35 
EncodeStringList(const string * strings,int64 n,string * out)36 void EncodeStringList(const string* strings, int64 n, string* out) {
37   out->clear();
38   for (int i = 0; i < n; ++i) {
39     core::PutVarint32(out, strings[i].size());
40   }
41   for (int i = 0; i < n; ++i) {
42     out->append(strings[i]);
43   }
44 }
45 
DecodeStringList(const string & src,string * strings,int64 n)46 bool DecodeStringList(const string& src, string* strings, int64 n) {
47   std::vector<uint32> sizes(n);
48   StringPiece reader(src);
49   int64 tot = 0;
50   for (auto& v : sizes) {
51     if (!core::GetVarint32(&reader, &v)) return false;
52     tot += v;
53   }
54   if (tot != static_cast<int64>(reader.size())) {
55     return false;
56   }
57 
58   string* data = strings;
59   for (int64 i = 0; i < n; ++i, ++data) {
60     auto size = sizes[i];
61     if (size > reader.size()) {
62       return false;
63     }
64     data->assign(reader.data(), size);
65     reader.remove_prefix(size);
66   }
67 
68   return true;
69 }
70 
CopyFromArray(string * s,const char * base,size_t bytes)71 void CopyFromArray(string* s, const char* base, size_t bytes) {
72   s->assign(base, bytes);
73 }
74 
75 class StringListEncoderImpl : public StringListEncoder {
76  public:
StringListEncoderImpl(string * out)77   explicit StringListEncoderImpl(string* out) : out_(out) {}
78   ~StringListEncoderImpl() override = default;
79 
Append(const protobuf::MessageLite & m)80   void Append(const protobuf::MessageLite& m) override {
81     core::PutVarint32(out_, m.ByteSizeLong());
82     tensorflow::string serialized_message;
83     m.AppendToString(&serialized_message);
84     strings::StrAppend(&rest_, serialized_message);
85   }
86 
Append(const string & s)87   void Append(const string& s) override {
88     core::PutVarint32(out_, s.length());
89     strings::StrAppend(&rest_, s);
90   }
91 
Finalize()92   void Finalize() override { strings::StrAppend(out_, rest_); }
93 
94  private:
95   string* out_;
96   string rest_;
97 };
98 
99 class StringListDecoderImpl : public StringListDecoder {
100  public:
StringListDecoderImpl(const string & in)101   explicit StringListDecoderImpl(const string& in) : reader_(in) {}
102   ~StringListDecoderImpl() override = default;
103 
ReadSizes(std::vector<uint32> * sizes)104   bool ReadSizes(std::vector<uint32>* sizes) override {
105     int64 total = 0;
106     for (auto& size : *sizes) {
107       if (!core::GetVarint32(&reader_, &size)) return false;
108       total += size;
109     }
110     if (total != static_cast<int64>(reader_.size())) {
111       return false;
112     }
113     return true;
114   }
115 
Data(uint32 size)116   const char* Data(uint32 size) override {
117     const char* data = reader_.data();
118     reader_.remove_prefix(size);
119     return data;
120   }
121 
122  private:
123   StringPiece reader_;
124 };
125 
NewStringListEncoder(string * out)126 std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) {
127   return std::unique_ptr<StringListEncoder>(new StringListEncoderImpl(out));
128 }
129 
NewStringListDecoder(const string & in)130 std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) {
131   return std::unique_ptr<StringListDecoder>(new StringListDecoderImpl(in));
132 }
133 
134 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
AssignRefCounted(StringPiece src,core::RefCounted * obj,Cord * out)135 void AssignRefCounted(StringPiece src, core::RefCounted* obj, Cord* out) {
136   obj->Ref();
137   out->Clear();
138   // Defines a lambda to unref "obj" when Cord deletes this piece of
139   // memory. +[] converts the lambda to a C style function pointer.
140   auto cleanup = +[](absl::string_view donotcare, void* obj) {
141     reinterpret_cast<core::RefCounted*>(obj)->Unref();
142   };
143   out->AppendExternalMemory(absl::string_view(src.data(), src.size()), obj,
144                             cleanup);
145 }
146 
EncodeStringList(const string * strings,int64 n,Cord * out)147 void EncodeStringList(const string* strings, int64 n, Cord* out) {
148   out->Clear();
149   for (int i = 0; i < n; ++i) {
150     ::strings::CordAppendVarint(strings[i].size(), out);
151   }
152   for (int i = 0; i < n; ++i) {
153     out->Append(strings[i]);
154   }
155 }
156 
DecodeStringList(const Cord & src,string * strings,int64 n)157 bool DecodeStringList(const Cord& src, string* strings, int64 n) {
158   std::vector<uint32> sizes(n);
159   CordReader reader(src);
160   int64 tot = 0;
161   for (auto& v : sizes) {
162     if (!::strings::CordReaderReadVarint(&reader, &v)) return false;
163     tot += v;
164   }
165   if (tot != reader.Available()) {
166     return false;
167   }
168   string* data = strings;
169   for (int i = 0; i < n; ++i, ++data) {
170     auto size = sizes[i];
171     if (size > reader.Available()) {
172       return false;
173     }
174     gtl::STLStringResizeUninitialized(data, size);
175     reader.ReadN(size, gtl::string_as_array(data));
176   }
177   return true;
178 }
179 
CopyFromArray(Cord * c,const char * base,size_t bytes)180 void CopyFromArray(Cord* c, const char* base, size_t bytes) {
181   c->CopyFrom(base, bytes);
182 }
183 
184 class CordStringListEncoderImpl : public StringListEncoder {
185  public:
CordStringListEncoderImpl(Cord * out)186   explicit CordStringListEncoderImpl(Cord* out) : out_(out) {}
187   ~CordStringListEncoderImpl() override = default;
188 
Append(const protobuf::MessageLite & m)189   void Append(const protobuf::MessageLite& m) override {
190     ::strings::CordAppendVarint(m.ByteSizeLong(), out_);
191     m.AppendToString(&rest_);
192   }
193 
Append(const string & s)194   void Append(const string& s) override {
195     ::strings::CordAppendVarint(s.length(), out_);
196     rest_.append(s.data(), s.size());
197   }
198 
Finalize()199   void Finalize() override { out_->Append(rest_); }
200 
201  private:
202   Cord* out_;
203   string rest_;
204 };
205 
206 class CordStringListDecoderImpl : public StringListDecoder {
207  public:
CordStringListDecoderImpl(const Cord & in)208   explicit CordStringListDecoderImpl(const Cord& in) : reader_(in) {}
209   ~CordStringListDecoderImpl() override = default;
210 
ReadSizes(std::vector<uint32> * sizes)211   bool ReadSizes(std::vector<uint32>* sizes) override {
212     int64 total = 0;
213     for (auto& size : *sizes) {
214       if (!::strings::CordReaderReadVarint(&reader_, &size)) return false;
215       total += size;
216     }
217     if (total != static_cast<int64>(reader_.Available())) {
218       return false;
219     }
220     return true;
221   }
222 
Data(uint32 size)223   const char* Data(uint32 size) override {
224     tmp_.resize(size);
225     reader_.ReadN(size, tmp_.data());
226     return tmp_.data();
227   }
228 
229  private:
230   CordReader reader_;
231   std::vector<char> tmp_;
232 };
233 
NewStringListEncoder(Cord * out)234 std::unique_ptr<StringListEncoder> NewStringListEncoder(Cord* out) {
235   return std::unique_ptr<StringListEncoder>(new CordStringListEncoderImpl(out));
236 }
237 
NewStringListDecoder(const Cord & in)238 std::unique_ptr<StringListDecoder> NewStringListDecoder(const Cord& in) {
239   return std::unique_ptr<StringListDecoder>(new CordStringListDecoderImpl(in));
240 }
241 
242 #endif  // defined(TENSORFLOW_PROTOBUF_USES_CORD)
243 
244 }  // namespace port
245 }  // namespace tensorflow
246