• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "src/buffer.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <type_traits>
21 
22 namespace amber {
23 namespace {
24 
// Returns the sign bit (bit 31) of the 32-bit float bit pattern |hex_float|:
// 1 when the bit is set, 0 otherwise.
uint16_t FloatSign(const uint32_t hex_float) {
  return static_cast<uint16_t>(hex_float >> 31U);
}
29 
// Returns the 5-bit half-float exponent field for the 32-bit float bit
// pattern |hex_float|. The 8-bit float exponent is rebiased by subtracting
// 112 (the difference between the float bias 127 and the half bias 15).
//
// Zero and denormal inputs have an exponent field of 0; they return 0 rather
// than letting the unsigned subtraction wrap around. Any other input whose
// rebiased exponent does not fit in 5 bits trips the assert.
uint16_t FloatExponent(const uint32_t hex_float) {
  const uint32_t float_exponent_mask = (1U << 8U) - 1U;
  const uint32_t exponent_bits = (hex_float >> 23U) & float_exponent_mask;

  // Handle zero and denormals.
  if (exponent_bits == 0U)
    return 0;

  const uint32_t exponent = exponent_bits - 112U;
  const uint32_t half_exponent_mask = (1U << 5U) - 1U;
  assert(((exponent & ~half_exponent_mask) == 0U) && "Float exponent overflow");
  return static_cast<uint16_t>(exponent & half_exponent_mask);
}
37 
// Returns the mantissa field (the low 23 bits) of the 32-bit float bit
// pattern |hex_float|. The field is 23 bits wide, so the result is a
// uint32_t rather than a uint16_t.
uint32_t FloatMantissa(const uint32_t hex_float) {
  const uint32_t mantissa_mask = (1U << 23U) - 1U;
  return hex_float & mantissa_mask;
}
43 
44 // Convert 32 bits float |value| to 16 bits float based on IEEE-754.
FloatToHexFloat16(const float value)45 uint16_t FloatToHexFloat16(const float value) {
46   const uint32_t* hex = reinterpret_cast<const uint32_t*>(&value);
47   return static_cast<uint16_t>(
48       static_cast<uint16_t>(FloatSign(*hex) << 15U) |
49       static_cast<uint16_t>(FloatExponent(*hex) << 10U) |
50       static_cast<uint16_t>(FloatMantissa(*hex) >> 13U));
51 }
52 
// Reinterprets the byte pointer |values| as a pointer to T. No copy is made;
// the returned pointer addresses the same memory.
template <typename T>
T* ValuesAs(uint8_t* values) {
  return reinterpret_cast<T*>(values);
}
57 
// Reads one value of type T from each of |buf1| and |buf2| and returns
// (value at |buf1|) - (value at |buf2|) as a double.
//
// The values are loaded with memcpy, so the pointers do not need to be
// aligned for T.
template <typename T>
double Sub(const uint8_t* buf1, const uint8_t* buf2) {
  T lhs;
  T rhs;
  std::memcpy(&lhs, buf1, sizeof(T));
  std::memcpy(&rhs, buf2, sizeof(T));

  // Unsigned subtraction wraps modulo 2^N, so for 32 and 64 bit unsigned
  // types a smaller-minus-larger difference would come out as a huge positive
  // number. Subtract the other way round and negate the result instead.
  if (std::is_unsigned<T>::value && lhs < rhs)
    return -static_cast<double>(rhs - lhs);

  return static_cast<double>(lhs - rhs);
}
63 
CalculateDiff(const Format::Segment * seg,const uint8_t * buf1,const uint8_t * buf2)64 double CalculateDiff(const Format::Segment* seg,
65                      const uint8_t* buf1,
66                      const uint8_t* buf2) {
67   FormatMode mode = seg->GetFormatMode();
68   uint32_t num_bits = seg->GetNumBits();
69   if (type::Type::IsInt8(mode, num_bits))
70     return Sub<int8_t>(buf1, buf2);
71   if (type::Type::IsInt16(mode, num_bits))
72     return Sub<int16_t>(buf1, buf2);
73   if (type::Type::IsInt32(mode, num_bits))
74     return Sub<int32_t>(buf1, buf2);
75   if (type::Type::IsInt64(mode, num_bits))
76     return Sub<int64_t>(buf1, buf2);
77   if (type::Type::IsUint8(mode, num_bits))
78     return Sub<uint8_t>(buf1, buf2);
79   if (type::Type::IsUint16(mode, num_bits))
80     return Sub<uint16_t>(buf1, buf2);
81   if (type::Type::IsUint32(mode, num_bits))
82     return Sub<uint32_t>(buf1, buf2);
83   if (type::Type::IsUint64(mode, num_bits))
84     return Sub<uint64_t>(buf1, buf2);
85   // TODO(dsinclair): Handle float16 ...
86   if (type::Type::IsFloat16(mode, num_bits)) {
87     assert(false && "Float16 suppport not implemented");
88     return 0.0;
89   }
90   if (type::Type::IsFloat32(mode, num_bits))
91     return Sub<float>(buf1, buf2);
92   if (type::Type::IsFloat64(mode, num_bits))
93     return Sub<double>(buf1, buf2);
94 
95   assert(false && "NOTREACHED");
96   return 0.0;
97 }
98 
99 }  // namespace
100 
101 Buffer::Buffer() = default;
102 
Buffer(BufferType type)103 Buffer::Buffer(BufferType type) : buffer_type_(type) {}
104 
105 Buffer::~Buffer() = default;
106 
CopyTo(Buffer * buffer) const107 Result Buffer::CopyTo(Buffer* buffer) const {
108   if (buffer->width_ != width_)
109     return Result("Buffer::CopyBaseFields() buffers have a different width");
110   if (buffer->height_ != height_)
111     return Result("Buffer::CopyBaseFields() buffers have a different height");
112   if (buffer->element_count_ != element_count_)
113     return Result("Buffer::CopyBaseFields() buffers have a different size");
114   buffer->bytes_ = bytes_;
115   return {};
116 }
117 
IsEqual(Buffer * buffer) const118 Result Buffer::IsEqual(Buffer* buffer) const {
119   auto result = CheckCompability(buffer);
120   if (!result.IsSuccess())
121     return result;
122 
123   uint32_t num_different = 0;
124   uint32_t first_different_index = 0;
125   uint8_t first_different_left = 0;
126   uint8_t first_different_right = 0;
127   for (uint32_t i = 0; i < bytes_.size(); ++i) {
128     if (bytes_[i] != buffer->bytes_[i]) {
129       if (num_different == 0) {
130         first_different_index = i;
131         first_different_left = bytes_[i];
132         first_different_right = buffer->bytes_[i];
133       }
134       num_different++;
135     }
136   }
137 
138   if (num_different) {
139     return Result{"Buffers have different values. " +
140                   std::to_string(num_different) +
141                   " values differed, first difference at byte " +
142                   std::to_string(first_different_index) + " values " +
143                   std::to_string(first_different_left) +
144                   " != " + std::to_string(first_different_right)};
145   }
146 
147   return {};
148 }
149 
CalculateDiffs(const Buffer * buffer) const150 std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
151   std::vector<double> diffs;
152 
153   auto* buf_1_ptr = GetValues<uint8_t>();
154   auto* buf_2_ptr = buffer->GetValues<uint8_t>();
155   const auto& segments = format_->GetSegments();
156   for (size_t i = 0; i < ElementCount(); ++i) {
157     for (const auto& seg : segments) {
158       if (seg.IsPadding()) {
159         buf_1_ptr += seg.PaddingBytes();
160         buf_2_ptr += seg.PaddingBytes();
161         continue;
162       }
163 
164       diffs.push_back(CalculateDiff(&seg, buf_1_ptr, buf_2_ptr));
165 
166       buf_1_ptr += seg.SizeInBytes();
167       buf_2_ptr += seg.SizeInBytes();
168     }
169   }
170 
171   return diffs;
172 }
173 
CheckCompability(Buffer * buffer) const174 Result Buffer::CheckCompability(Buffer* buffer) const {
175   if (!buffer->format_->Equal(format_))
176     return Result{"Buffers have a different format"};
177   if (buffer->element_count_ != element_count_)
178     return Result{"Buffers have a different size"};
179   if (buffer->width_ != width_)
180     return Result{"Buffers have a different width"};
181   if (buffer->height_ != height_)
182     return Result{"Buffers have a different height"};
183   if (buffer->ValueCount() != ValueCount())
184     return Result{"Buffers have a different number of values"};
185 
186   return {};
187 }
188 
CompareRMSE(Buffer * buffer,float tolerance) const189 Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
190   auto result = CheckCompability(buffer);
191   if (!result.IsSuccess())
192     return result;
193 
194   auto diffs = CalculateDiffs(buffer);
195   double sum = 0.0;
196   for (const auto val : diffs)
197     sum += (val * val);
198 
199   sum /= static_cast<double>(diffs.size());
200   double rmse = std::sqrt(sum);
201   if (rmse > static_cast<double>(tolerance)) {
202     return Result("Root Mean Square Error of " + std::to_string(rmse) +
203                   " is greater than tolerance of " + std::to_string(tolerance));
204   }
205 
206   return {};
207 }
208 
GetHistogramForChannel(uint32_t channel,uint32_t num_bins) const209 std::vector<uint64_t> Buffer::GetHistogramForChannel(uint32_t channel,
210                                                      uint32_t num_bins) const {
211   assert(num_bins == 256);
212   std::vector<uint64_t> bins(num_bins, 0);
213   auto* buf_ptr = GetValues<uint8_t>();
214   auto num_channels = format_->InputNeededPerElement();
215   uint32_t channel_id = 0;
216 
217   for (size_t i = 0; i < ElementCount(); ++i) {
218     for (const auto& seg : format_->GetSegments()) {
219       if (seg.IsPadding()) {
220         buf_ptr += seg.PaddingBytes();
221         continue;
222       }
223       if (channel_id == channel) {
224         assert(type::Type::IsUint8(seg.GetFormatMode(), seg.GetNumBits()));
225         const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr);
226         bins[bin]++;
227       }
228       buf_ptr += seg.SizeInBytes();
229       channel_id = (channel_id + 1) % num_channels;
230     }
231   }
232 
233   return bins;
234 }
235 
CompareHistogramEMD(Buffer * buffer,float tolerance) const236 Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const {
237   auto result = CheckCompability(buffer);
238   if (!result.IsSuccess())
239     return result;
240 
241   const int num_bins = 256;
242   auto num_channels = format_->InputNeededPerElement();
243   for (auto segment : format_->GetSegments()) {
244     if (!type::Type::IsUint8(segment.GetFormatMode(), segment.GetNumBits()) ||
245         num_channels != 4) {
246       return Result(
247           "EMD comparison only supports 8bit unorm format with four channels.");
248     }
249   }
250 
251   std::vector<std::vector<uint64_t>> histogram1;
252   std::vector<std::vector<uint64_t>> histogram2;
253   for (uint32_t c = 0; c < num_channels; ++c) {
254     histogram1.push_back(GetHistogramForChannel(c, num_bins));
255     histogram2.push_back(buffer->GetHistogramForChannel(c, num_bins));
256   }
257 
258   // Earth movers's distance: Calculate the minimal cost of moving "earth" to
259   // transform the first histogram into the second, where each bin of the
260   // histogram can be thought of as a column of units of earth. The cost is the
261   // amount of earth moved times the distance carried (the distance is the
262   // number of adjacent bins over which the earth is carried). Calculate this
263   // using the cumulative difference of the bins, which works as long as both
264   // histograms have the same amount of earth. Sum the absolute values of the
265   // cumulative difference to get the final cost of how much (and how far) the
266   // earth was moved.
267   double max_emd = 0;
268 
269   for (uint32_t c = 0; c < num_channels; ++c) {
270     double diff_total = 0;
271     double diff_accum = 0;
272 
273     for (size_t i = 0; i < num_bins; ++i) {
274       double hist_normalized_1 =
275           static_cast<double>(histogram1[c][i]) / element_count_;
276       double hist_normalized_2 =
277           static_cast<double>(histogram2[c][i]) / buffer->element_count_;
278       diff_accum += hist_normalized_1 - hist_normalized_2;
279       diff_total += fabs(diff_accum);
280     }
281     // Normalize to range 0..1
282     double emd = diff_total / num_bins;
283     max_emd = std::max(max_emd, emd);
284   }
285 
286   if (max_emd > static_cast<double>(tolerance)) {
287     return Result("Histogram EMD value of " + std::to_string(max_emd) +
288                   " is greater than tolerance of " + std::to_string(tolerance));
289   }
290 
291   return {};
292 }
293 
SetData(const std::vector<Value> & data)294 Result Buffer::SetData(const std::vector<Value>& data) {
295   return SetDataWithOffset(data, 0);
296 }
297 
RecalculateMaxSizeInBytes(const std::vector<Value> & data,uint32_t offset)298 Result Buffer::RecalculateMaxSizeInBytes(const std::vector<Value>& data,
299                                          uint32_t offset) {
300   // Multiply by the input needed because the value count will use the needed
301   // input as the multiplier
302   uint32_t value_count =
303       ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
304       static_cast<uint32_t>(data.size());
305   uint32_t element_count = value_count;
306   if (!format_->IsPacked()) {
307     // This divides by the needed input values, not the values per element.
308     // The assumption being the values coming in are read from the input,
309     // where components are specified. The needed values maybe less then the
310     // values per element.
311     element_count = value_count / format_->InputNeededPerElement();
312   }
313   if (GetMaxSizeInBytes() < element_count * format_->SizeInBytes())
314     SetMaxSizeInBytes(element_count * format_->SizeInBytes());
315   return {};
316 }
317 
SetDataWithOffset(const std::vector<Value> & data,uint32_t offset)318 Result Buffer::SetDataWithOffset(const std::vector<Value>& data,
319                                  uint32_t offset) {
320   // Multiply by the input needed because the value count will use the needed
321   // input as the multiplier
322   uint32_t value_count =
323       ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
324       static_cast<uint32_t>(data.size());
325 
326   // The buffer should only be resized to become bigger. This means that if a
327   // command was run to set the buffer size we'll honour that size until a
328   // request happens to make the buffer bigger.
329   if (value_count > ValueCount())
330     SetValueCount(value_count);
331 
332   // Even if the value count doesn't change, the buffer is still resized because
333   // this maybe the first time data is set into the buffer.
334   bytes_.resize(GetSizeInBytes());
335 
336   // Set the new memory to zero to be on the safe side.
337   uint32_t new_space =
338       (static_cast<uint32_t>(data.size()) / format_->InputNeededPerElement()) *
339       format_->SizeInBytes();
340   assert(new_space + offset <= GetSizeInBytes());
341 
342   if (new_space > 0)
343     memset(bytes_.data() + offset, 0, new_space);
344 
345   if (data.size() > (ElementCount() * format_->InputNeededPerElement()))
346     return Result("Mismatched number of items in buffer");
347 
348   uint8_t* ptr = bytes_.data() + offset;
349   const auto& segments = format_->GetSegments();
350   for (uint32_t i = 0; i < data.size();) {
351     for (const auto& seg : segments) {
352       if (seg.IsPadding()) {
353         ptr += seg.PaddingBytes();
354         continue;
355       }
356 
357       Value v = data[i++];
358       ptr += WriteValueFromComponent(v, seg.GetFormatMode(), seg.GetNumBits(),
359                                      ptr);
360       if (i >= data.size())
361         break;
362     }
363   }
364   return {};
365 }
366 
WriteValueFromComponent(const Value & value,FormatMode mode,uint32_t num_bits,uint8_t * ptr)367 uint32_t Buffer::WriteValueFromComponent(const Value& value,
368                                          FormatMode mode,
369                                          uint32_t num_bits,
370                                          uint8_t* ptr) {
371   if (type::Type::IsInt8(mode, num_bits)) {
372     *(ValuesAs<int8_t>(ptr)) = value.AsInt8();
373     return sizeof(int8_t);
374   }
375   if (type::Type::IsInt16(mode, num_bits)) {
376     *(ValuesAs<int16_t>(ptr)) = value.AsInt16();
377     return sizeof(int16_t);
378   }
379   if (type::Type::IsInt32(mode, num_bits)) {
380     *(ValuesAs<int32_t>(ptr)) = value.AsInt32();
381     return sizeof(int32_t);
382   }
383   if (type::Type::IsInt64(mode, num_bits)) {
384     *(ValuesAs<int64_t>(ptr)) = value.AsInt64();
385     return sizeof(int64_t);
386   }
387   if (type::Type::IsUint8(mode, num_bits)) {
388     *(ValuesAs<uint8_t>(ptr)) = value.AsUint8();
389     return sizeof(uint8_t);
390   }
391   if (type::Type::IsUint16(mode, num_bits)) {
392     *(ValuesAs<uint16_t>(ptr)) = value.AsUint16();
393     return sizeof(uint16_t);
394   }
395   if (type::Type::IsUint32(mode, num_bits)) {
396     *(ValuesAs<uint32_t>(ptr)) = value.AsUint32();
397     return sizeof(uint32_t);
398   }
399   if (type::Type::IsUint64(mode, num_bits)) {
400     *(ValuesAs<uint64_t>(ptr)) = value.AsUint64();
401     return sizeof(uint64_t);
402   }
403   if (type::Type::IsFloat16(mode, num_bits)) {
404     *(ValuesAs<uint16_t>(ptr)) = FloatToHexFloat16(value.AsFloat());
405     return sizeof(uint16_t);
406   }
407   if (type::Type::IsFloat32(mode, num_bits)) {
408     *(ValuesAs<float>(ptr)) = value.AsFloat();
409     return sizeof(float);
410   }
411   if (type::Type::IsFloat64(mode, num_bits)) {
412     *(ValuesAs<double>(ptr)) = value.AsDouble();
413     return sizeof(double);
414   }
415 
416   // The float 10 and float 11 sizes are only used in PACKED formats.
417   assert(false && "Not reached");
418   return 0;
419 }
420 
SetSizeInElements(uint32_t element_count)421 void Buffer::SetSizeInElements(uint32_t element_count) {
422   element_count_ = element_count;
423   bytes_.resize(element_count * format_->SizeInBytes());
424 }
425 
SetSizeInBytes(uint32_t size_in_bytes)426 void Buffer::SetSizeInBytes(uint32_t size_in_bytes) {
427   assert(size_in_bytes % format_->SizeInBytes() == 0);
428   element_count_ = size_in_bytes / format_->SizeInBytes();
429   bytes_.resize(size_in_bytes);
430 }
431 
SetMaxSizeInBytes(uint32_t max_size_in_bytes)432 void Buffer::SetMaxSizeInBytes(uint32_t max_size_in_bytes) {
433   max_size_in_bytes_ = max_size_in_bytes;
434 }
435 
GetMaxSizeInBytes() const436 uint32_t Buffer::GetMaxSizeInBytes() const {
437   if (max_size_in_bytes_ != 0)
438     return max_size_in_bytes_;
439   else
440     return GetSizeInBytes();
441 }
442 
SetDataFromBuffer(const Buffer * src,uint32_t offset)443 Result Buffer::SetDataFromBuffer(const Buffer* src, uint32_t offset) {
444   if (bytes_.size() < offset + src->bytes_.size())
445     bytes_.resize(offset + src->bytes_.size());
446 
447   std::memcpy(bytes_.data() + offset, src->bytes_.data(), src->bytes_.size());
448   element_count_ =
449       static_cast<uint32_t>(bytes_.size()) / format_->SizeInBytes();
450   return {};
451 }
452 
453 }  // namespace amber
454