• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Helper routines for encoding/decoding tensor contents.
17 #ifndef TENSORFLOW_PLATFORM_TENSOR_CODING_H_
18 #define TENSORFLOW_PLATFORM_TENSOR_CODING_H_
19 
20 #include <string>
21 
22 #include "tensorflow/core/platform/platform.h"
23 #include "tensorflow/core/platform/protobuf.h"
24 #include "tensorflow/core/platform/refcount.h"
25 #include "tensorflow/core/platform/stringpiece.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 namespace port {
30 
31 // Store src contents in *out.  If backing memory for src is shared with *out,
32 // will ref obj during the call and will arrange to unref obj when no
33 // longer needed.
34 void AssignRefCounted(StringPiece src, core::RefCounted* obj, std::string* out);
35 
// Copies the src.size() bytes of src into dst[0, src.size()).
//
// dst must point to a buffer with room for at least src.size() bytes. No
// terminating NUL is appended; bytes of dst past src.size() are not touched.
inline void CopyToArray(const std::string& src, char* dst) {
  // std::string::copy is declared by <string>, which this header includes
  // directly, so this helper does not depend on memcpy being visible through
  // a transitive include of <cstring>.
  src.copy(dst, src.size());
}
40 
// Copies the subrange [pos, pos + n) of src into dst.
//
// If pos >= src.size() nothing is copied and dst is left untouched. If
// pos + n > src.size() only the subrange [pos, src.size()) is copied. dst must
// have room for min(n, src.size() - pos) bytes; no terminating NUL is
// appended.
inline void CopySubrangeToArray(const std::string& src, size_t pos, size_t n,
                                char* dst) {
  if (pos >= src.size()) return;
  // std::string::copy clamps the count to src.size() - pos on its own, so no
  // std::min (<algorithm>) or memcpy (<cstring>) is needed here; <string> is
  // the only header this relies on. The guard above guarantees pos is in
  // range, so copy() cannot throw std::out_of_range.
  src.copy(dst, n, pos);
}
49 
50 // Store encoding of strings[0..n-1] in *out.
51 void EncodeStringList(const tstring* strings, int64 n, std::string* out);
52 
53 // Decode n strings from src and store in strings[0..n-1].
54 // Returns true if successful, false on parse error.
55 bool DecodeStringList(const std::string& src, tstring* strings, int64 n);
56 
// Assigns the `bytes` bytes base[0..bytes-1] to *s.
void CopyFromArray(std::string* s,
                   const char* base,
                   size_t bytes);
59 
60 // Encodes sequences of strings and serialized protocol buffers into a string.
61 // Normal usage consists of zero or more calls to Append() and a single call to
62 // Finalize().
63 class StringListEncoder {
64  public:
65   virtual ~StringListEncoder() = default;
66 
67   // Encodes the given protocol buffer. This may not be called after Finalize().
68   virtual void Append(const protobuf::MessageLite& m) = 0;
69 
70   // Encodes the given string. This may not be called after Finalize().
71   virtual void Append(const std::string& s) = 0;
72 
73   // Signals end of the encoding process. No other calls are allowed after this.
74   virtual void Finalize() = 0;
75 };
76 
77 // Decodes a string into sequences of strings (which may represent serialized
78 // protocol buffers). Normal usage involves a single call to ReadSizes() in
79 // order to retrieve the length of all the strings in the sequence. For each
80 // size returned a call to Data() is expected and will return the actual
81 // string.
82 class StringListDecoder {
83  public:
84   virtual ~StringListDecoder() = default;
85 
86   // Populates the given vector with the lengths of each string in the sequence
87   // being decoded. Upon returning the vector is guaranteed to contain as many
88   // elements as there are strings in the sequence.
89   virtual bool ReadSizes(std::vector<uint32>* sizes) = 0;
90 
91   // Returns a pointer to the next string in the sequence, then prepares for the
92   // next call by advancing 'size' characters in the sequence.
93   virtual const char* Data(uint32 size) = 0;
94 };
95 
96 std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out);
97 std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in);
98 
99 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
100 // Store src contents in *out.  If backing memory for src is shared with *out,
101 // will ref obj during the call and will arrange to unref obj when no
102 // longer needed.
103 void AssignRefCounted(StringPiece src, core::RefCounted* obj, absl::Cord* out);
104 
105 // TODO(kmensah): Macro guard this with a check for Cord support.
CopyToArray(const absl::Cord & src,char * dst)106 inline void CopyToArray(const absl::Cord& src, char* dst) {
107   src.CopyToArray(dst);
108 }
109 
110 // Copy n bytes of src to dst. If pos >= src.size() the result is empty.
111 // If pos + n > src.size() the subrange [pos, size()) is copied.
CopySubrangeToArray(const absl::Cord & src,int64 pos,int64 n,char * dst)112 inline void CopySubrangeToArray(const absl::Cord& src, int64 pos, int64 n,
113                                 char* dst) {
114   src.Subcord(pos, n).CopyToArray(dst);
115 }
116 
117 // Store encoding of strings[0..n-1] in *out.
118 void EncodeStringList(const tstring* strings, int64 n, absl::Cord* out);
119 
120 // Decode n strings from src and store in strings[0..n-1].
121 // Returns true if successful, false on parse error.
122 bool DecodeStringList(const absl::Cord& src, std::string* strings, int64 n);
123 bool DecodeStringList(const absl::Cord& src, tstring* strings, int64 n);
124 
125 // Assigns base[0..bytes-1] to *c
126 void CopyFromArray(absl::Cord* c, const char* base, size_t bytes);
127 
128 std::unique_ptr<StringListEncoder> NewStringListEncoder(absl::Cord* out);
129 std::unique_ptr<StringListDecoder> NewStringListDecoder(const absl::Cord& in);
130 #endif  // defined(TENSORFLOW_PROTOBUF_USES_CORD)
131 
132 }  // namespace port
133 }  // namespace tensorflow
134 
135 #endif  // TENSORFLOW_PLATFORM_TENSOR_CODING_H_
136