1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/core/tensor.h"
18
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
24
25 #include "fcp/aggregation/core/datatype.h"
26 #include "fcp/aggregation/core/tensor_shape.h"
27 #include "fcp/base/monitoring.h"
28
29 #ifndef FCP_NANOLIBC
30 #include "fcp/aggregation/core/tensor.pb.h"
31 #include "google/protobuf/io/coded_stream.h"
32 #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
33 #endif
34
35 namespace fcp {
36 namespace aggregation {
37
CheckValid() const38 Status Tensor::CheckValid() const {
39 if (dtype_ == DT_INVALID) {
40 return FCP_STATUS(FAILED_PRECONDITION) << "Invalid Tensor dtype.";
41 }
42
43 size_t value_size = 0;
44 DTYPE_CASES(dtype_, T, value_size = sizeof(T));
45
46 // Verify that the storage is consistent with the value size in terms of
47 // size and alignment.
48 FCP_RETURN_IF_ERROR(data_->CheckValid(value_size));
49
50 // Verify that the total size of the data is consistent with the value type
51 // and the shape.
52 // TODO(team): Implement sparse tensors.
53 if (data_->byte_size() != shape_.NumElements() * value_size) {
54 return FCP_STATUS(FAILED_PRECONDITION)
55 << "TensorData byte_size is inconsistent with the Tensor dtype and "
56 "shape.";
57 }
58
59 return FCP_STATUS(OK);
60 }
61
Create(DataType dtype,TensorShape shape,std::unique_ptr<TensorData> data)62 StatusOr<Tensor> Tensor::Create(DataType dtype, TensorShape shape,
63 std::unique_ptr<TensorData> data) {
64 Tensor tensor(dtype, std::move(shape), std::move(data));
65 FCP_RETURN_IF_ERROR(tensor.CheckValid());
66 return std::move(tensor);
67 }
68
69 #ifndef FCP_NANOLIBC
70
71 // SerializedContentNumericData implements TensorData by wrapping the serialized
72 // content string and using it directly as a backing storage. This relies on the
73 // fact that the serialized content uses the same layout as in memory
74 // representation if we assume that this code runs on a little-endian system.
75 // TODO(team): Ensure little-endianness.
76 class SerializedContentNumericData : public TensorData {
77 public:
SerializedContentNumericData(std::string content)78 explicit SerializedContentNumericData(std::string content)
79 : content_(std::move(content)) {}
80 ~SerializedContentNumericData() override = default;
81
82 // Implementation of TensorData methods.
byte_size() const83 size_t byte_size() const override { return content_.size(); }
data() const84 const void* data() const override { return content_.data(); }
85
86 private:
87 std::string content_;
88 };
89
90 // Converts the tensor data to a serialized blob saved as the content field
91 // in the TensorProto. The `num` argument is needed in case the number of
92 // values can't be derived from the TensorData size.
93 template <typename T>
EncodeContent(const TensorData * data,size_t num)94 std::string EncodeContent(const TensorData* data, size_t num) {
95 // Default encoding of tensor data, valid only for numeric data types.
96 return std::string(reinterpret_cast<const char*>(data->data()),
97 data->byte_size());
98 }
99
100 // Specialization of EncodeContent for DT_STRING data type.
101 template <>
EncodeContent(const TensorData * data,size_t num)102 std::string EncodeContent<string_view>(const TensorData* data, size_t num) {
103 std::string content;
104 google::protobuf::io::StringOutputStream out(&content);
105 google::protobuf::io::CodedOutputStream coded_out(&out);
106 auto ptr = reinterpret_cast<const string_view*>(data->data());
107
108 // Write all string sizes as Varint64.
109 for (size_t i = 0; i < num; ++i) {
110 coded_out.WriteVarint64(ptr[i].size());
111 }
112
113 // Write all string contents.
114 for (size_t i = 0; i < num; ++i) {
115 coded_out.WriteRaw(ptr[i].data(), static_cast<int>(ptr[i].size()));
116 }
117
118 return content;
119 }
120
121 // Converts the serialized TensorData content stored in TensorProto to an
122 // instance of TensorData. The `num` argument is needed in case the number of
123 // values can't be derived from the content size.
124 template <typename T>
DecodeContent(std::string content,size_t num)125 StatusOr<std::unique_ptr<TensorData>> DecodeContent(std::string content,
126 size_t num) {
127 // Default decoding of tensor data, valid only for numeric data types.
128 return std::make_unique<SerializedContentNumericData>(std::move(content));
129 }
130
131 // Wraps the serialized TensorData content stored and surfaces it as pointer
132 // string_view values pointing back into the wrapped content. This class is
133 // be created and initialized from within the DecodeContent<string_view>().
134 class SerializedContentStringData : public TensorData {
135 public:
136 SerializedContentStringData() = default;
137 ~SerializedContentStringData() override = default;
138
139 // Implementation of TensorData methods.
byte_size() const140 size_t byte_size() const override {
141 return string_views_.size() * sizeof(string_view);
142 }
data() const143 const void* data() const override { return string_views_.data(); }
144
145 // Initializes the string_view values to point to the strings embedded in the
146 // content.
Initialize(std::string content,size_t num)147 Status Initialize(std::string content, size_t num) {
148 content_ = std::move(content);
149 google::protobuf::io::ArrayInputStream input(content_.data(),
150 static_cast<int>(content_.size()));
151 google::protobuf::io::CodedInputStream coded_input(&input);
152
153 // The pointer to the first string in the content is unknown at this point
154 // because there are multiple string sizes at the front, all encoded as
155 // VarInts. To avoid using the extra storage this code reuses the same
156 // string_views_ vector in the two passes. First it initializes the data
157 // pointers to start with the beginning of the content. Then in the second
158 // pass it shifts all data pointers to where strings actually begin in the
159 // content.
160 string_views_.resize(num);
161 size_t cumulative_size = 0;
162
163 // The first pass reads the string sizes;
164 for (size_t i = 0; i < num; ++i) {
165 size_t size;
166 if (!coded_input.ReadVarint64(&size)) {
167 return FCP_STATUS(INVALID_ARGUMENT)
168 << "Expected to read " << num
169 << " string values but the input tensor content doesn't contain "
170 "a size for the "
171 << i << "th string. The content size is " << content_.size()
172 << " bytes.";
173 }
174 string_views_[i] = string_view(content_.data() + cumulative_size, size);
175 cumulative_size += size;
176 }
177
178 // The current position in the input stream after reading all the string
179 // sizes. The input stream must be at the beginning of the first string now.
180 size_t offset = coded_input.CurrentPosition();
181
182 // Verify that the content is large enough.
183 if (content_.size() < offset + cumulative_size) {
184 return FCP_STATUS(INVALID_ARGUMENT)
185 << "Input tensor content has insufficient size to store " << num
186 << " string values. The content size is " << content_.size()
187 << " bytes, but " << offset + cumulative_size
188 << " bytes are required.";
189 }
190
191 // The second pass offsets string_view pointers so that the first one points
192 // to the first string embedded in the content, then all others are shifted
193 // by the same offset to point to subsequent strings.
194 for (size_t i = 0; i < num; ++i) {
195 string_views_[i] = string_view(string_views_[i].data() + offset,
196 string_views_[i].size());
197 }
198
199 return FCP_STATUS(OK);
200 }
201
202 private:
203 std::string content_;
204 std::vector<string_view> string_views_;
205 };
206
207 template <>
DecodeContent(std::string content,size_t num)208 StatusOr<std::unique_ptr<TensorData>> DecodeContent<string_view>(
209 std::string content, size_t num) {
210 auto tensor_data = std::make_unique<SerializedContentStringData>();
211 FCP_RETURN_IF_ERROR(tensor_data->Initialize(std::move(content), num));
212 return tensor_data;
213 }
214
FromProto(const TensorProto & tensor_proto)215 StatusOr<Tensor> Tensor::FromProto(const TensorProto& tensor_proto) {
216 FCP_ASSIGN_OR_RETURN(TensorShape shape,
217 TensorShape::FromProto(tensor_proto.shape()));
218 // TODO(team): The num_values is valid only for dense tensors.
219 size_t num_values = shape.NumElements();
220 StatusOr<std::unique_ptr<TensorData>> data;
221 DTYPE_CASES(tensor_proto.dtype(), T,
222 data = DecodeContent<T>(tensor_proto.content(), num_values));
223 FCP_RETURN_IF_ERROR(data);
224 return Create(tensor_proto.dtype(), std::move(shape),
225 std::move(data).value());
226 }
227
FromProto(TensorProto && tensor_proto)228 StatusOr<Tensor> Tensor::FromProto(TensorProto&& tensor_proto) {
229 FCP_ASSIGN_OR_RETURN(TensorShape shape,
230 TensorShape::FromProto(tensor_proto.shape()));
231 // TODO(team): The num_values is valid only for dense tensors.
232 size_t num_values = shape.NumElements();
233 std::string content = std::move(*tensor_proto.mutable_content());
234 StatusOr<std::unique_ptr<TensorData>> data;
235 DTYPE_CASES(tensor_proto.dtype(), T,
236 data = DecodeContent<T>(std::move(content), num_values));
237 FCP_RETURN_IF_ERROR(data);
238 return Create(tensor_proto.dtype(), std::move(shape),
239 std::move(data).value());
240 }
241
ToProto() const242 TensorProto Tensor::ToProto() const {
243 TensorProto tensor_proto;
244 tensor_proto.set_dtype(dtype_);
245 *(tensor_proto.mutable_shape()) = shape_.ToProto();
246 // TODO(team): The num_values is valid only for dense tensors.
247 size_t num_values = shape_.NumElements();
248 std::string content;
249 DTYPE_CASES(dtype_, T, content = EncodeContent<T>(data_.get(), num_values));
250 *(tensor_proto.mutable_content()) = std::move(content);
251 return tensor_proto;
252 }
253
254 #endif // FCP_NANOLIBC
255
256 } // namespace aggregation
257 } // namespace fcp
258