• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/core/tensor.h"
18 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
24 
25 #include "fcp/aggregation/core/datatype.h"
26 #include "fcp/aggregation/core/tensor_shape.h"
27 #include "fcp/base/monitoring.h"
28 
29 #ifndef FCP_NANOLIBC
30 #include "fcp/aggregation/core/tensor.pb.h"
31 #include "google/protobuf/io/coded_stream.h"
32 #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
33 #endif
34 
35 namespace fcp {
36 namespace aggregation {
37 
CheckValid() const38 Status Tensor::CheckValid() const {
39   if (dtype_ == DT_INVALID) {
40     return FCP_STATUS(FAILED_PRECONDITION) << "Invalid Tensor dtype.";
41   }
42 
43   size_t value_size = 0;
44   DTYPE_CASES(dtype_, T, value_size = sizeof(T));
45 
46   // Verify that the storage is consistent with the value size in terms of
47   // size and alignment.
48   FCP_RETURN_IF_ERROR(data_->CheckValid(value_size));
49 
50   // Verify that the total size of the data is consistent with the value type
51   // and the shape.
52   // TODO(team): Implement sparse tensors.
53   if (data_->byte_size() != shape_.NumElements() * value_size) {
54     return FCP_STATUS(FAILED_PRECONDITION)
55            << "TensorData byte_size is inconsistent with the Tensor dtype and "
56               "shape.";
57   }
58 
59   return FCP_STATUS(OK);
60 }
61 
Create(DataType dtype,TensorShape shape,std::unique_ptr<TensorData> data)62 StatusOr<Tensor> Tensor::Create(DataType dtype, TensorShape shape,
63                                 std::unique_ptr<TensorData> data) {
64   Tensor tensor(dtype, std::move(shape), std::move(data));
65   FCP_RETURN_IF_ERROR(tensor.CheckValid());
66   return std::move(tensor);
67 }
68 
69 #ifndef FCP_NANOLIBC
70 
71 // SerializedContentNumericData implements TensorData by wrapping the serialized
72 // content string and using it directly as a backing storage. This relies on the
73 // fact that the serialized content uses the same layout as in memory
74 // representation if we assume that this code runs on a little-endian system.
75 // TODO(team): Ensure little-endianness.
76 class SerializedContentNumericData : public TensorData {
77  public:
SerializedContentNumericData(std::string content)78   explicit SerializedContentNumericData(std::string content)
79       : content_(std::move(content)) {}
80   ~SerializedContentNumericData() override = default;
81 
82   // Implementation of TensorData methods.
byte_size() const83   size_t byte_size() const override { return content_.size(); }
data() const84   const void* data() const override { return content_.data(); }
85 
86  private:
87   std::string content_;
88 };
89 
90 // Converts the tensor data to a serialized blob saved as the content field
91 // in the TensorProto.  The `num` argument is needed in case the number of
92 // values can't be derived from the TensorData size.
93 template <typename T>
EncodeContent(const TensorData * data,size_t num)94 std::string EncodeContent(const TensorData* data, size_t num) {
95   // Default encoding of tensor data, valid only for numeric data types.
96   return std::string(reinterpret_cast<const char*>(data->data()),
97                      data->byte_size());
98 }
99 
100 // Specialization of EncodeContent for DT_STRING data type.
101 template <>
EncodeContent(const TensorData * data,size_t num)102 std::string EncodeContent<string_view>(const TensorData* data, size_t num) {
103   std::string content;
104   google::protobuf::io::StringOutputStream out(&content);
105   google::protobuf::io::CodedOutputStream coded_out(&out);
106   auto ptr = reinterpret_cast<const string_view*>(data->data());
107 
108   // Write all string sizes as Varint64.
109   for (size_t i = 0; i < num; ++i) {
110     coded_out.WriteVarint64(ptr[i].size());
111   }
112 
113   // Write all string contents.
114   for (size_t i = 0; i < num; ++i) {
115     coded_out.WriteRaw(ptr[i].data(), static_cast<int>(ptr[i].size()));
116   }
117 
118   return content;
119 }
120 
121 // Converts the serialized TensorData content stored in TensorProto to an
122 // instance of TensorData.   The `num` argument is needed in case the number of
123 // values can't be derived from the content size.
124 template <typename T>
DecodeContent(std::string content,size_t num)125 StatusOr<std::unique_ptr<TensorData>> DecodeContent(std::string content,
126                                                     size_t num) {
127   // Default decoding of tensor data, valid only for numeric data types.
128   return std::make_unique<SerializedContentNumericData>(std::move(content));
129 }
130 
131 // Wraps the serialized TensorData content stored and surfaces it as pointer
132 // string_view values pointing back into the wrapped content. This class is
133 // be created and initialized from within the DecodeContent<string_view>().
134 class SerializedContentStringData : public TensorData {
135  public:
136   SerializedContentStringData() = default;
137   ~SerializedContentStringData() override = default;
138 
139   // Implementation of TensorData methods.
byte_size() const140   size_t byte_size() const override {
141     return string_views_.size() * sizeof(string_view);
142   }
data() const143   const void* data() const override { return string_views_.data(); }
144 
145   // Initializes the string_view values to point to the strings embedded in the
146   // content.
Initialize(std::string content,size_t num)147   Status Initialize(std::string content, size_t num) {
148     content_ = std::move(content);
149     google::protobuf::io::ArrayInputStream input(content_.data(),
150                                        static_cast<int>(content_.size()));
151     google::protobuf::io::CodedInputStream coded_input(&input);
152 
153     // The pointer to the first string in the content is unknown at this point
154     // because there are multiple string sizes at the front, all encoded as
155     // VarInts.  To avoid using the extra storage this code reuses the same
156     // string_views_ vector in the two passes. First it initializes the data
157     // pointers to start with the beginning of the content. Then in the second
158     // pass it shifts all data pointers to where strings actually begin in the
159     // content.
160     string_views_.resize(num);
161     size_t cumulative_size = 0;
162 
163     // The first pass reads the string sizes;
164     for (size_t i = 0; i < num; ++i) {
165       size_t size;
166       if (!coded_input.ReadVarint64(&size)) {
167         return FCP_STATUS(INVALID_ARGUMENT)
168                << "Expected to read " << num
169                << " string values but the input tensor content doesn't contain "
170                   "a size for the "
171                << i << "th string. The content size is " << content_.size()
172                << " bytes.";
173       }
174       string_views_[i] = string_view(content_.data() + cumulative_size, size);
175       cumulative_size += size;
176     }
177 
178     // The current position in the input stream after reading all the string
179     // sizes. The input stream must be at the beginning of the first string now.
180     size_t offset = coded_input.CurrentPosition();
181 
182     // Verify that the content is large enough.
183     if (content_.size() < offset + cumulative_size) {
184       return FCP_STATUS(INVALID_ARGUMENT)
185              << "Input tensor content has insufficient size to store " << num
186              << " string values. The content size is " << content_.size()
187              << " bytes, but " << offset + cumulative_size
188              << " bytes are required.";
189     }
190 
191     // The second pass offsets string_view pointers so that the first one points
192     // to the first string embedded in the content, then all others are shifted
193     // by the same offset to point to subsequent strings.
194     for (size_t i = 0; i < num; ++i) {
195       string_views_[i] = string_view(string_views_[i].data() + offset,
196                                      string_views_[i].size());
197     }
198 
199     return FCP_STATUS(OK);
200   }
201 
202  private:
203   std::string content_;
204   std::vector<string_view> string_views_;
205 };
206 
207 template <>
DecodeContent(std::string content,size_t num)208 StatusOr<std::unique_ptr<TensorData>> DecodeContent<string_view>(
209     std::string content, size_t num) {
210   auto tensor_data = std::make_unique<SerializedContentStringData>();
211   FCP_RETURN_IF_ERROR(tensor_data->Initialize(std::move(content), num));
212   return tensor_data;
213 }
214 
FromProto(const TensorProto & tensor_proto)215 StatusOr<Tensor> Tensor::FromProto(const TensorProto& tensor_proto) {
216   FCP_ASSIGN_OR_RETURN(TensorShape shape,
217                        TensorShape::FromProto(tensor_proto.shape()));
218   // TODO(team): The num_values is valid only for dense tensors.
219   size_t num_values = shape.NumElements();
220   StatusOr<std::unique_ptr<TensorData>> data;
221   DTYPE_CASES(tensor_proto.dtype(), T,
222               data = DecodeContent<T>(tensor_proto.content(), num_values));
223   FCP_RETURN_IF_ERROR(data);
224   return Create(tensor_proto.dtype(), std::move(shape),
225                 std::move(data).value());
226 }
227 
FromProto(TensorProto && tensor_proto)228 StatusOr<Tensor> Tensor::FromProto(TensorProto&& tensor_proto) {
229   FCP_ASSIGN_OR_RETURN(TensorShape shape,
230                        TensorShape::FromProto(tensor_proto.shape()));
231   // TODO(team): The num_values is valid only for dense tensors.
232   size_t num_values = shape.NumElements();
233   std::string content = std::move(*tensor_proto.mutable_content());
234   StatusOr<std::unique_ptr<TensorData>> data;
235   DTYPE_CASES(tensor_proto.dtype(), T,
236               data = DecodeContent<T>(std::move(content), num_values));
237   FCP_RETURN_IF_ERROR(data);
238   return Create(tensor_proto.dtype(), std::move(shape),
239                 std::move(data).value());
240 }
241 
ToProto() const242 TensorProto Tensor::ToProto() const {
243   TensorProto tensor_proto;
244   tensor_proto.set_dtype(dtype_);
245   *(tensor_proto.mutable_shape()) = shape_.ToProto();
246   // TODO(team): The num_values is valid only for dense tensors.
247   size_t num_values = shape_.NumElements();
248   std::string content;
249   DTYPE_CASES(dtype_, T, content = EncodeContent<T>(data_.get(), num_values));
250   *(tensor_proto.mutable_content()) = std::move(content);
251   return tensor_proto;
252 }
253 
254 #endif  // FCP_NANOLIBC
255 
256 }  // namespace aggregation
257 }  // namespace fcp
258