1 // Copyright 2018 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "src/buffer.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <type_traits>
21
22 namespace amber {
23 namespace {
24
// Extracts the sign bit (bit 31) from the bit pattern |hex_float| of a 32 bits
// float. The result is 1 when the bit is set and 0 otherwise.
uint16_t FloatSign(const uint32_t hex_float) {
  const uint32_t kSignShift = 31U;
  const uint32_t sign_bit = hex_float >> kSignShift;
  return static_cast<uint16_t>(sign_bit);
}
29
// Returns the exponent of the 32 bits float bit pattern |hex_float|, re-biased
// for a 16 bits float: the 8 bit exponent field is reduced by 112 (the
// difference between the float bias 127 and the half float bias 15) so that it
// fits the 5 bit exponent field of a half float.
//
// An all-zero exponent field (zero and denormal floats) yields 0. Without this
// early return the unsigned subtraction below wraps around, which trips the
// assert in debug builds and yields a bogus exponent of 16 in release builds,
// i.e. 0.0f would be encoded as the half float 2.0.
//
// Any other exponent which does not fit into 5 bits still asserts.
uint16_t FloatExponent(const uint32_t hex_float) {
  const uint32_t exponent_bits = (hex_float >> 23U) & ((1U << 8U) - 1U);

  // Handle zero and denormals.
  if (exponent_bits == 0U)
    return 0;

  const uint32_t exponent = exponent_bits - 112U;
  const uint32_t half_exponent_mask = (1U << 5U) - 1U;
  assert(((exponent & ~half_exponent_mask) == 0U) && "Float exponent overflow");
  return static_cast<uint16_t>(exponent & half_exponent_mask);
}
37
// Extracts the mantissa from the bit pattern |hex_float| of a 32 bits float.
// The mantissa occupies the low 23 bits, which is why the result is returned
// as a uint32_t.
uint32_t FloatMantissa(const uint32_t hex_float) {
  const uint32_t kMantissaMask = 0x007FFFFFU;
  return hex_float & kMantissaMask;
}
43
44 // Convert 32 bits float |value| to 16 bits float based on IEEE-754.
FloatToHexFloat16(const float value)45 uint16_t FloatToHexFloat16(const float value) {
46 const uint32_t* hex = reinterpret_cast<const uint32_t*>(&value);
47 return static_cast<uint16_t>(
48 static_cast<uint16_t>(FloatSign(*hex) << 15U) |
49 static_cast<uint16_t>(FloatExponent(*hex) << 10U) |
50 static_cast<uint16_t>(FloatMantissa(*hex) >> 13U));
51 }
52
// Reinterprets the byte pointer |values| as a pointer to T. The address is
// left unchanged.
template <typename T>
T* ValuesAs(uint8_t* values) {
  void* untyped = values;
  return static_cast<T*>(untyped);
}
57
// Difference |lhs| - |rhs| for floating point types. The subtraction is done
// in T and the result is widened to double afterwards.
template <typename T>
double SubImpl(T lhs, T rhs, std::false_type /* is_integral */) {
  return static_cast<double>(lhs - rhs);
}

// Difference |lhs| - |rhs| for integer types. Subtracting directly in T wraps
// around for unsigned types when |rhs| > |lhs| (e.g. 1U - 2U is 4294967295)
// and can overflow for signed types. Instead the magnitude is computed in the
// matching unsigned type, where it always fits, and the sign is applied after
// the conversion to double.
template <typename T>
double SubImpl(T lhs, T rhs, std::true_type /* is_integral */) {
  using U = typename std::make_unsigned<T>::type;
  if (lhs >= rhs) {
    // The outer cast undoes the integer promotion of types narrower than int.
    return static_cast<double>(
        static_cast<U>(static_cast<U>(lhs) - static_cast<U>(rhs)));
  }
  return -static_cast<double>(
      static_cast<U>(static_cast<U>(rhs) - static_cast<U>(lhs)));
}

// Reads one value of type T from each of |buf1| and |buf2| and returns the
// difference (value at |buf1|) - (value at |buf2|) as a double.
template <typename T>
double Sub(const uint8_t* buf1, const uint8_t* buf2) {
  // Load through memcpy so that the buffers need not be aligned for T.
  T lhs;
  T rhs;
  std::memcpy(&lhs, buf1, sizeof(T));
  std::memcpy(&rhs, buf2, sizeof(T));
  return SubImpl<T>(lhs, rhs, std::is_integral<T>());
}
63
CalculateDiff(const Format::Segment * seg,const uint8_t * buf1,const uint8_t * buf2)64 double CalculateDiff(const Format::Segment* seg,
65 const uint8_t* buf1,
66 const uint8_t* buf2) {
67 FormatMode mode = seg->GetFormatMode();
68 uint32_t num_bits = seg->GetNumBits();
69 if (type::Type::IsInt8(mode, num_bits))
70 return Sub<int8_t>(buf1, buf2);
71 if (type::Type::IsInt16(mode, num_bits))
72 return Sub<int16_t>(buf1, buf2);
73 if (type::Type::IsInt32(mode, num_bits))
74 return Sub<int32_t>(buf1, buf2);
75 if (type::Type::IsInt64(mode, num_bits))
76 return Sub<int64_t>(buf1, buf2);
77 if (type::Type::IsUint8(mode, num_bits))
78 return Sub<uint8_t>(buf1, buf2);
79 if (type::Type::IsUint16(mode, num_bits))
80 return Sub<uint16_t>(buf1, buf2);
81 if (type::Type::IsUint32(mode, num_bits))
82 return Sub<uint32_t>(buf1, buf2);
83 if (type::Type::IsUint64(mode, num_bits))
84 return Sub<uint64_t>(buf1, buf2);
85 // TODO(dsinclair): Handle float16 ...
86 if (type::Type::IsFloat16(mode, num_bits)) {
87 assert(false && "Float16 suppport not implemented");
88 return 0.0;
89 }
90 if (type::Type::IsFloat32(mode, num_bits))
91 return Sub<float>(buf1, buf2);
92 if (type::Type::IsFloat64(mode, num_bits))
93 return Sub<double>(buf1, buf2);
94
95 assert(false && "NOTREACHED");
96 return 0.0;
97 }
98
99 } // namespace
100
101 Buffer::Buffer() = default;
102
Buffer(BufferType type)103 Buffer::Buffer(BufferType type) : buffer_type_(type) {}
104
105 Buffer::~Buffer() = default;
106
CopyTo(Buffer * buffer) const107 Result Buffer::CopyTo(Buffer* buffer) const {
108 if (buffer->width_ != width_)
109 return Result("Buffer::CopyBaseFields() buffers have a different width");
110 if (buffer->height_ != height_)
111 return Result("Buffer::CopyBaseFields() buffers have a different height");
112 if (buffer->element_count_ != element_count_)
113 return Result("Buffer::CopyBaseFields() buffers have a different size");
114 buffer->bytes_ = bytes_;
115 return {};
116 }
117
IsEqual(Buffer * buffer) const118 Result Buffer::IsEqual(Buffer* buffer) const {
119 auto result = CheckCompability(buffer);
120 if (!result.IsSuccess())
121 return result;
122
123 uint32_t num_different = 0;
124 uint32_t first_different_index = 0;
125 uint8_t first_different_left = 0;
126 uint8_t first_different_right = 0;
127 for (uint32_t i = 0; i < bytes_.size(); ++i) {
128 if (bytes_[i] != buffer->bytes_[i]) {
129 if (num_different == 0) {
130 first_different_index = i;
131 first_different_left = bytes_[i];
132 first_different_right = buffer->bytes_[i];
133 }
134 num_different++;
135 }
136 }
137
138 if (num_different) {
139 return Result{"Buffers have different values. " +
140 std::to_string(num_different) +
141 " values differed, first difference at byte " +
142 std::to_string(first_different_index) + " values " +
143 std::to_string(first_different_left) +
144 " != " + std::to_string(first_different_right)};
145 }
146
147 return {};
148 }
149
CalculateDiffs(const Buffer * buffer) const150 std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
151 std::vector<double> diffs;
152
153 auto* buf_1_ptr = GetValues<uint8_t>();
154 auto* buf_2_ptr = buffer->GetValues<uint8_t>();
155 const auto& segments = format_->GetSegments();
156 for (size_t i = 0; i < ElementCount(); ++i) {
157 for (const auto& seg : segments) {
158 if (seg.IsPadding()) {
159 buf_1_ptr += seg.PaddingBytes();
160 buf_2_ptr += seg.PaddingBytes();
161 continue;
162 }
163
164 diffs.push_back(CalculateDiff(&seg, buf_1_ptr, buf_2_ptr));
165
166 buf_1_ptr += seg.SizeInBytes();
167 buf_2_ptr += seg.SizeInBytes();
168 }
169 }
170
171 return diffs;
172 }
173
CheckCompability(Buffer * buffer) const174 Result Buffer::CheckCompability(Buffer* buffer) const {
175 if (!buffer->format_->Equal(format_))
176 return Result{"Buffers have a different format"};
177 if (buffer->element_count_ != element_count_)
178 return Result{"Buffers have a different size"};
179 if (buffer->width_ != width_)
180 return Result{"Buffers have a different width"};
181 if (buffer->height_ != height_)
182 return Result{"Buffers have a different height"};
183 if (buffer->ValueCount() != ValueCount())
184 return Result{"Buffers have a different number of values"};
185
186 return {};
187 }
188
CompareRMSE(Buffer * buffer,float tolerance) const189 Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
190 auto result = CheckCompability(buffer);
191 if (!result.IsSuccess())
192 return result;
193
194 auto diffs = CalculateDiffs(buffer);
195 double sum = 0.0;
196 for (const auto val : diffs)
197 sum += (val * val);
198
199 sum /= static_cast<double>(diffs.size());
200 double rmse = std::sqrt(sum);
201 if (rmse > static_cast<double>(tolerance)) {
202 return Result("Root Mean Square Error of " + std::to_string(rmse) +
203 " is greater than tolerance of " + std::to_string(tolerance));
204 }
205
206 return {};
207 }
208
GetHistogramForChannel(uint32_t channel,uint32_t num_bins) const209 std::vector<uint64_t> Buffer::GetHistogramForChannel(uint32_t channel,
210 uint32_t num_bins) const {
211 assert(num_bins == 256);
212 std::vector<uint64_t> bins(num_bins, 0);
213 auto* buf_ptr = GetValues<uint8_t>();
214 auto num_channels = format_->InputNeededPerElement();
215 uint32_t channel_id = 0;
216
217 for (size_t i = 0; i < ElementCount(); ++i) {
218 for (const auto& seg : format_->GetSegments()) {
219 if (seg.IsPadding()) {
220 buf_ptr += seg.PaddingBytes();
221 continue;
222 }
223 if (channel_id == channel) {
224 assert(type::Type::IsUint8(seg.GetFormatMode(), seg.GetNumBits()));
225 const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr);
226 bins[bin]++;
227 }
228 buf_ptr += seg.SizeInBytes();
229 channel_id = (channel_id + 1) % num_channels;
230 }
231 }
232
233 return bins;
234 }
235
CompareHistogramEMD(Buffer * buffer,float tolerance) const236 Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const {
237 auto result = CheckCompability(buffer);
238 if (!result.IsSuccess())
239 return result;
240
241 const int num_bins = 256;
242 auto num_channels = format_->InputNeededPerElement();
243 for (auto segment : format_->GetSegments()) {
244 if (!type::Type::IsUint8(segment.GetFormatMode(), segment.GetNumBits()) ||
245 num_channels != 4) {
246 return Result(
247 "EMD comparison only supports 8bit unorm format with four channels.");
248 }
249 }
250
251 std::vector<std::vector<uint64_t>> histogram1;
252 std::vector<std::vector<uint64_t>> histogram2;
253 for (uint32_t c = 0; c < num_channels; ++c) {
254 histogram1.push_back(GetHistogramForChannel(c, num_bins));
255 histogram2.push_back(buffer->GetHistogramForChannel(c, num_bins));
256 }
257
258 // Earth movers's distance: Calculate the minimal cost of moving "earth" to
259 // transform the first histogram into the second, where each bin of the
260 // histogram can be thought of as a column of units of earth. The cost is the
261 // amount of earth moved times the distance carried (the distance is the
262 // number of adjacent bins over which the earth is carried). Calculate this
263 // using the cumulative difference of the bins, which works as long as both
264 // histograms have the same amount of earth. Sum the absolute values of the
265 // cumulative difference to get the final cost of how much (and how far) the
266 // earth was moved.
267 double max_emd = 0;
268
269 for (uint32_t c = 0; c < num_channels; ++c) {
270 double diff_total = 0;
271 double diff_accum = 0;
272
273 for (size_t i = 0; i < num_bins; ++i) {
274 double hist_normalized_1 =
275 static_cast<double>(histogram1[c][i]) / element_count_;
276 double hist_normalized_2 =
277 static_cast<double>(histogram2[c][i]) / buffer->element_count_;
278 diff_accum += hist_normalized_1 - hist_normalized_2;
279 diff_total += fabs(diff_accum);
280 }
281 // Normalize to range 0..1
282 double emd = diff_total / num_bins;
283 max_emd = std::max(max_emd, emd);
284 }
285
286 if (max_emd > static_cast<double>(tolerance)) {
287 return Result("Histogram EMD value of " + std::to_string(max_emd) +
288 " is greater than tolerance of " + std::to_string(tolerance));
289 }
290
291 return {};
292 }
293
SetData(const std::vector<Value> & data)294 Result Buffer::SetData(const std::vector<Value>& data) {
295 return SetDataWithOffset(data, 0);
296 }
297
RecalculateMaxSizeInBytes(const std::vector<Value> & data,uint32_t offset)298 Result Buffer::RecalculateMaxSizeInBytes(const std::vector<Value>& data,
299 uint32_t offset) {
300 // Multiply by the input needed because the value count will use the needed
301 // input as the multiplier
302 uint32_t value_count =
303 ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
304 static_cast<uint32_t>(data.size());
305 uint32_t element_count = value_count;
306 if (!format_->IsPacked()) {
307 // This divides by the needed input values, not the values per element.
308 // The assumption being the values coming in are read from the input,
309 // where components are specified. The needed values maybe less then the
310 // values per element.
311 element_count = value_count / format_->InputNeededPerElement();
312 }
313 if (GetMaxSizeInBytes() < element_count * format_->SizeInBytes())
314 SetMaxSizeInBytes(element_count * format_->SizeInBytes());
315 return {};
316 }
317
SetDataWithOffset(const std::vector<Value> & data,uint32_t offset)318 Result Buffer::SetDataWithOffset(const std::vector<Value>& data,
319 uint32_t offset) {
320 // Multiply by the input needed because the value count will use the needed
321 // input as the multiplier
322 uint32_t value_count =
323 ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
324 static_cast<uint32_t>(data.size());
325
326 // The buffer should only be resized to become bigger. This means that if a
327 // command was run to set the buffer size we'll honour that size until a
328 // request happens to make the buffer bigger.
329 if (value_count > ValueCount())
330 SetValueCount(value_count);
331
332 // Even if the value count doesn't change, the buffer is still resized because
333 // this maybe the first time data is set into the buffer.
334 bytes_.resize(GetSizeInBytes());
335
336 // Set the new memory to zero to be on the safe side.
337 uint32_t new_space =
338 (static_cast<uint32_t>(data.size()) / format_->InputNeededPerElement()) *
339 format_->SizeInBytes();
340 assert(new_space + offset <= GetSizeInBytes());
341
342 if (new_space > 0)
343 memset(bytes_.data() + offset, 0, new_space);
344
345 if (data.size() > (ElementCount() * format_->InputNeededPerElement()))
346 return Result("Mismatched number of items in buffer");
347
348 uint8_t* ptr = bytes_.data() + offset;
349 const auto& segments = format_->GetSegments();
350 for (uint32_t i = 0; i < data.size();) {
351 for (const auto& seg : segments) {
352 if (seg.IsPadding()) {
353 ptr += seg.PaddingBytes();
354 continue;
355 }
356
357 Value v = data[i++];
358 ptr += WriteValueFromComponent(v, seg.GetFormatMode(), seg.GetNumBits(),
359 ptr);
360 if (i >= data.size())
361 break;
362 }
363 }
364 return {};
365 }
366
WriteValueFromComponent(const Value & value,FormatMode mode,uint32_t num_bits,uint8_t * ptr)367 uint32_t Buffer::WriteValueFromComponent(const Value& value,
368 FormatMode mode,
369 uint32_t num_bits,
370 uint8_t* ptr) {
371 if (type::Type::IsInt8(mode, num_bits)) {
372 *(ValuesAs<int8_t>(ptr)) = value.AsInt8();
373 return sizeof(int8_t);
374 }
375 if (type::Type::IsInt16(mode, num_bits)) {
376 *(ValuesAs<int16_t>(ptr)) = value.AsInt16();
377 return sizeof(int16_t);
378 }
379 if (type::Type::IsInt32(mode, num_bits)) {
380 *(ValuesAs<int32_t>(ptr)) = value.AsInt32();
381 return sizeof(int32_t);
382 }
383 if (type::Type::IsInt64(mode, num_bits)) {
384 *(ValuesAs<int64_t>(ptr)) = value.AsInt64();
385 return sizeof(int64_t);
386 }
387 if (type::Type::IsUint8(mode, num_bits)) {
388 *(ValuesAs<uint8_t>(ptr)) = value.AsUint8();
389 return sizeof(uint8_t);
390 }
391 if (type::Type::IsUint16(mode, num_bits)) {
392 *(ValuesAs<uint16_t>(ptr)) = value.AsUint16();
393 return sizeof(uint16_t);
394 }
395 if (type::Type::IsUint32(mode, num_bits)) {
396 *(ValuesAs<uint32_t>(ptr)) = value.AsUint32();
397 return sizeof(uint32_t);
398 }
399 if (type::Type::IsUint64(mode, num_bits)) {
400 *(ValuesAs<uint64_t>(ptr)) = value.AsUint64();
401 return sizeof(uint64_t);
402 }
403 if (type::Type::IsFloat16(mode, num_bits)) {
404 *(ValuesAs<uint16_t>(ptr)) = FloatToHexFloat16(value.AsFloat());
405 return sizeof(uint16_t);
406 }
407 if (type::Type::IsFloat32(mode, num_bits)) {
408 *(ValuesAs<float>(ptr)) = value.AsFloat();
409 return sizeof(float);
410 }
411 if (type::Type::IsFloat64(mode, num_bits)) {
412 *(ValuesAs<double>(ptr)) = value.AsDouble();
413 return sizeof(double);
414 }
415
416 // The float 10 and float 11 sizes are only used in PACKED formats.
417 assert(false && "Not reached");
418 return 0;
419 }
420
SetSizeInElements(uint32_t element_count)421 void Buffer::SetSizeInElements(uint32_t element_count) {
422 element_count_ = element_count;
423 bytes_.resize(element_count * format_->SizeInBytes());
424 }
425
SetSizeInBytes(uint32_t size_in_bytes)426 void Buffer::SetSizeInBytes(uint32_t size_in_bytes) {
427 assert(size_in_bytes % format_->SizeInBytes() == 0);
428 element_count_ = size_in_bytes / format_->SizeInBytes();
429 bytes_.resize(size_in_bytes);
430 }
431
SetMaxSizeInBytes(uint32_t max_size_in_bytes)432 void Buffer::SetMaxSizeInBytes(uint32_t max_size_in_bytes) {
433 max_size_in_bytes_ = max_size_in_bytes;
434 }
435
GetMaxSizeInBytes() const436 uint32_t Buffer::GetMaxSizeInBytes() const {
437 if (max_size_in_bytes_ != 0)
438 return max_size_in_bytes_;
439 else
440 return GetSizeInBytes();
441 }
442
SetDataFromBuffer(const Buffer * src,uint32_t offset)443 Result Buffer::SetDataFromBuffer(const Buffer* src, uint32_t offset) {
444 if (bytes_.size() < offset + src->bytes_.size())
445 bytes_.resize(offset + src->bytes_.size());
446
447 std::memcpy(bytes_.data() + offset, src->bytes_.data(), src->bytes_.size());
448 element_count_ =
449 static_cast<uint32_t>(bytes_.size()) / format_->SizeInBytes();
450 return {};
451 }
452
453 } // namespace amber
454