• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/buffer.h"
16 
17 #include <algorithm>
18 #include <cassert>
19 #include <cmath>
20 #include <cstring>
21 
22 #include "src/float16_helper.h"
23 
24 namespace amber {
25 namespace {
26 
// Reinterprets the byte pointer |values| as a pointer to |T|. The address is
// returned unchanged; only the static pointer type differs.
template <typename T>
T* ValuesAs(uint8_t* values) {
  T* const typed_values = reinterpret_cast<T*>(values);
  return typed_values;
}
31 
// Loads a |T| from |bytes| using memcpy, so |bytes| does not have to be
// aligned for |T|.
template <typename T>
T LoadValue(const uint8_t* bytes) {
  T value;
  std::memcpy(&value, bytes, sizeof(T));
  return value;
}

// Returns (value at |buf1|) - (value at |buf2|) as a double, with both
// operands read as |T|.
//
// This primary template is the integer path. Subtracting directly in |T|
// wraps for unsigned 32/64-bit operands (1 - 2 became 4294967295.0) and
// overflows for signed ones, so the magnitude is computed with modular
// uint64_t arithmetic instead and the sign is applied afterwards. Converting
// a signed value to uint64_t sign-extends it, which makes the modular
// difference equal to the true magnitude for every integer width up to 64
// bits.
template <typename T>
double Sub(const uint8_t* buf1, const uint8_t* buf2) {
  const T lhs = LoadValue<T>(buf1);
  const T rhs = LoadValue<T>(buf2);
  if (lhs >= rhs) {
    return static_cast<double>(static_cast<uint64_t>(lhs) -
                               static_cast<uint64_t>(rhs));
  }
  return -static_cast<double>(static_cast<uint64_t>(rhs) -
                              static_cast<uint64_t>(lhs));
}

// Float operands are subtracted in float precision, then widened.
template <>
double Sub<float>(const uint8_t* buf1, const uint8_t* buf2) {
  return static_cast<double>(LoadValue<float>(buf1) - LoadValue<float>(buf2));
}

// Double operands are subtracted natively.
template <>
double Sub<double>(const uint8_t* buf1, const uint8_t* buf2) {
  return LoadValue<double>(buf1) - LoadValue<double>(buf2);
}
37 
CalculateDiff(const Format::Segment * seg,const uint8_t * buf1,const uint8_t * buf2)38 double CalculateDiff(const Format::Segment* seg,
39                      const uint8_t* buf1,
40                      const uint8_t* buf2) {
41   FormatMode mode = seg->GetFormatMode();
42   uint32_t num_bits = seg->GetNumBits();
43   if (type::Type::IsInt8(mode, num_bits))
44     return Sub<int8_t>(buf1, buf2);
45   if (type::Type::IsInt16(mode, num_bits))
46     return Sub<int16_t>(buf1, buf2);
47   if (type::Type::IsInt32(mode, num_bits))
48     return Sub<int32_t>(buf1, buf2);
49   if (type::Type::IsInt64(mode, num_bits))
50     return Sub<int64_t>(buf1, buf2);
51   if (type::Type::IsUint8(mode, num_bits))
52     return Sub<uint8_t>(buf1, buf2);
53   if (type::Type::IsUint16(mode, num_bits))
54     return Sub<uint16_t>(buf1, buf2);
55   if (type::Type::IsUint32(mode, num_bits))
56     return Sub<uint32_t>(buf1, buf2);
57   if (type::Type::IsUint64(mode, num_bits))
58     return Sub<uint64_t>(buf1, buf2);
59   if (type::Type::IsFloat16(mode, num_bits)) {
60     float val1 = float16::HexFloatToFloat(buf1, 16);
61     float val2 = float16::HexFloatToFloat(buf2, 16);
62     return static_cast<double>(val1 - val2);
63   }
64   if (type::Type::IsFloat32(mode, num_bits))
65     return Sub<float>(buf1, buf2);
66   if (type::Type::IsFloat64(mode, num_bits))
67     return Sub<double>(buf1, buf2);
68 
69   assert(false && "NOTREACHED");
70   return 0.0;
71 }
72 
73 }  // namespace
74 
// Construction and destruction need no custom logic; both special members are
// defaulted here, out of line.
Buffer::Buffer() = default;

Buffer::~Buffer() = default;
78 
CopyTo(Buffer * buffer) const79 Result Buffer::CopyTo(Buffer* buffer) const {
80   if (buffer->width_ != width_)
81     return Result("Buffer::CopyBaseFields() buffers have a different width");
82   if (buffer->height_ != height_)
83     return Result("Buffer::CopyBaseFields() buffers have a different height");
84   if (buffer->element_count_ != element_count_)
85     return Result("Buffer::CopyBaseFields() buffers have a different size");
86   buffer->bytes_ = bytes_;
87   return {};
88 }
89 
IsEqual(Buffer * buffer) const90 Result Buffer::IsEqual(Buffer* buffer) const {
91   auto result = CheckCompability(buffer);
92   if (!result.IsSuccess())
93     return result;
94 
95   uint32_t num_different = 0;
96   uint32_t first_different_index = 0;
97   uint8_t first_different_left = 0;
98   uint8_t first_different_right = 0;
99   for (uint32_t i = 0; i < bytes_.size(); ++i) {
100     if (bytes_[i] != buffer->bytes_[i]) {
101       if (num_different == 0) {
102         first_different_index = i;
103         first_different_left = bytes_[i];
104         first_different_right = buffer->bytes_[i];
105       }
106       num_different++;
107     }
108   }
109 
110   if (num_different) {
111     return Result{"Buffers have different values. " +
112                   std::to_string(num_different) +
113                   " values differed, first difference at byte " +
114                   std::to_string(first_different_index) + " values " +
115                   std::to_string(first_different_left) +
116                   " != " + std::to_string(first_different_right)};
117   }
118 
119   return {};
120 }
121 
CalculateDiffs(const Buffer * buffer) const122 std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
123   std::vector<double> diffs;
124 
125   auto* buf_1_ptr = GetValues<uint8_t>();
126   auto* buf_2_ptr = buffer->GetValues<uint8_t>();
127   const auto& segments = format_->GetSegments();
128   for (size_t i = 0; i < ElementCount(); ++i) {
129     for (const auto& seg : segments) {
130       if (seg.IsPadding()) {
131         buf_1_ptr += seg.PaddingBytes();
132         buf_2_ptr += seg.PaddingBytes();
133         continue;
134       }
135 
136       diffs.push_back(CalculateDiff(&seg, buf_1_ptr, buf_2_ptr));
137 
138       buf_1_ptr += seg.SizeInBytes();
139       buf_2_ptr += seg.SizeInBytes();
140     }
141   }
142 
143   return diffs;
144 }
145 
CheckCompability(Buffer * buffer) const146 Result Buffer::CheckCompability(Buffer* buffer) const {
147   if (!buffer->format_->Equal(format_))
148     return Result{"Buffers have a different format"};
149   if (buffer->element_count_ != element_count_)
150     return Result{"Buffers have a different size"};
151   if (buffer->width_ != width_)
152     return Result{"Buffers have a different width"};
153   if (buffer->height_ != height_)
154     return Result{"Buffers have a different height"};
155   if (buffer->ValueCount() != ValueCount())
156     return Result{"Buffers have a different number of values"};
157 
158   return {};
159 }
160 
CompareRMSE(Buffer * buffer,float tolerance) const161 Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
162   auto result = CheckCompability(buffer);
163   if (!result.IsSuccess())
164     return result;
165 
166   auto diffs = CalculateDiffs(buffer);
167   double sum = 0.0;
168   for (const auto val : diffs)
169     sum += (val * val);
170 
171   sum /= static_cast<double>(diffs.size());
172   double rmse = std::sqrt(sum);
173   if (rmse > static_cast<double>(tolerance)) {
174     return Result("Root Mean Square Error of " + std::to_string(rmse) +
175                   " is greater than tolerance of " + std::to_string(tolerance));
176   }
177 
178   return {};
179 }
180 
GetHistogramForChannel(uint32_t channel,uint32_t num_bins) const181 std::vector<uint64_t> Buffer::GetHistogramForChannel(uint32_t channel,
182                                                      uint32_t num_bins) const {
183   assert(num_bins == 256);
184   std::vector<uint64_t> bins(num_bins, 0);
185   auto* buf_ptr = GetValues<uint8_t>();
186   auto num_channels = format_->InputNeededPerElement();
187   uint32_t channel_id = 0;
188 
189   for (size_t i = 0; i < ElementCount(); ++i) {
190     for (const auto& seg : format_->GetSegments()) {
191       if (seg.IsPadding()) {
192         buf_ptr += seg.PaddingBytes();
193         continue;
194       }
195       if (channel_id == channel) {
196         assert(type::Type::IsUint8(seg.GetFormatMode(), seg.GetNumBits()));
197         const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr);
198         bins[bin]++;
199       }
200       buf_ptr += seg.SizeInBytes();
201       channel_id = (channel_id + 1) % num_channels;
202     }
203   }
204 
205   return bins;
206 }
207 
CompareHistogramEMD(Buffer * buffer,float tolerance) const208 Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const {
209   auto result = CheckCompability(buffer);
210   if (!result.IsSuccess())
211     return result;
212 
213   const int num_bins = 256;
214   auto num_channels = format_->InputNeededPerElement();
215   for (auto segment : format_->GetSegments()) {
216     if (!type::Type::IsUint8(segment.GetFormatMode(), segment.GetNumBits()) ||
217         num_channels != 4) {
218       return Result(
219           "EMD comparison only supports 8bit unorm format with four channels.");
220     }
221   }
222 
223   std::vector<std::vector<uint64_t>> histogram1;
224   std::vector<std::vector<uint64_t>> histogram2;
225   for (uint32_t c = 0; c < num_channels; ++c) {
226     histogram1.push_back(GetHistogramForChannel(c, num_bins));
227     histogram2.push_back(buffer->GetHistogramForChannel(c, num_bins));
228   }
229 
230   // Earth movers's distance: Calculate the minimal cost of moving "earth" to
231   // transform the first histogram into the second, where each bin of the
232   // histogram can be thought of as a column of units of earth. The cost is the
233   // amount of earth moved times the distance carried (the distance is the
234   // number of adjacent bins over which the earth is carried). Calculate this
235   // using the cumulative difference of the bins, which works as long as both
236   // histograms have the same amount of earth. Sum the absolute values of the
237   // cumulative difference to get the final cost of how much (and how far) the
238   // earth was moved.
239   double max_emd = 0;
240 
241   for (uint32_t c = 0; c < num_channels; ++c) {
242     double diff_total = 0;
243     double diff_accum = 0;
244 
245     for (size_t i = 0; i < num_bins; ++i) {
246       double hist_normalized_1 =
247           static_cast<double>(histogram1[c][i]) / element_count_;
248       double hist_normalized_2 =
249           static_cast<double>(histogram2[c][i]) / buffer->element_count_;
250       diff_accum += hist_normalized_1 - hist_normalized_2;
251       diff_total += fabs(diff_accum);
252     }
253     // Normalize to range 0..1
254     double emd = diff_total / num_bins;
255     max_emd = std::max(max_emd, emd);
256   }
257 
258   if (max_emd > static_cast<double>(tolerance)) {
259     return Result("Histogram EMD value of " + std::to_string(max_emd) +
260                   " is greater than tolerance of " + std::to_string(tolerance));
261   }
262 
263   return {};
264 }
265 
SetData(const std::vector<Value> & data)266 Result Buffer::SetData(const std::vector<Value>& data) {
267   return SetDataWithOffset(data, 0);
268 }
269 
RecalculateMaxSizeInBytes(const std::vector<Value> & data,uint32_t offset)270 Result Buffer::RecalculateMaxSizeInBytes(const std::vector<Value>& data,
271                                          uint32_t offset) {
272   // Multiply by the input needed because the value count will use the needed
273   // input as the multiplier
274   uint32_t value_count =
275       ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
276       static_cast<uint32_t>(data.size());
277   uint32_t element_count = value_count;
278   if (!format_->IsPacked()) {
279     // This divides by the needed input values, not the values per element.
280     // The assumption being the values coming in are read from the input,
281     // where components are specified. The needed values maybe less then the
282     // values per element.
283     element_count = value_count / format_->InputNeededPerElement();
284   }
285   if (GetMaxSizeInBytes() < element_count * format_->SizeInBytes())
286     SetMaxSizeInBytes(element_count * format_->SizeInBytes());
287   return {};
288 }
289 
SetDataWithOffset(const std::vector<Value> & data,uint32_t offset)290 Result Buffer::SetDataWithOffset(const std::vector<Value>& data,
291                                  uint32_t offset) {
292   // Multiply by the input needed because the value count will use the needed
293   // input as the multiplier
294   uint32_t value_count =
295       ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
296       static_cast<uint32_t>(data.size());
297 
298   // The buffer should only be resized to become bigger. This means that if a
299   // command was run to set the buffer size we'll honour that size until a
300   // request happens to make the buffer bigger.
301   if (value_count > ValueCount())
302     SetValueCount(value_count);
303 
304   // Even if the value count doesn't change, the buffer is still resized because
305   // this maybe the first time data is set into the buffer.
306   bytes_.resize(GetSizeInBytes());
307 
308   // Set the new memory to zero to be on the safe side.
309   uint32_t new_space =
310       (static_cast<uint32_t>(data.size()) / format_->InputNeededPerElement()) *
311       format_->SizeInBytes();
312   assert(new_space + offset <= GetSizeInBytes());
313 
314   if (new_space > 0)
315     memset(bytes_.data() + offset, 0, new_space);
316 
317   if (data.size() > (ElementCount() * format_->InputNeededPerElement()))
318     return Result("Mismatched number of items in buffer");
319 
320   uint8_t* ptr = bytes_.data() + offset;
321   const auto& segments = format_->GetSegments();
322   for (uint32_t i = 0; i < data.size();) {
323     for (const auto& seg : segments) {
324       if (seg.IsPadding()) {
325         ptr += seg.PaddingBytes();
326         continue;
327       }
328 
329       Value v = data[i++];
330       ptr += WriteValueFromComponent(v, seg.GetFormatMode(), seg.GetNumBits(),
331                                      ptr);
332       if (i >= data.size())
333         break;
334     }
335   }
336   return {};
337 }
338 
WriteValueFromComponent(const Value & value,FormatMode mode,uint32_t num_bits,uint8_t * ptr)339 uint32_t Buffer::WriteValueFromComponent(const Value& value,
340                                          FormatMode mode,
341                                          uint32_t num_bits,
342                                          uint8_t* ptr) {
343   if (type::Type::IsInt8(mode, num_bits)) {
344     *(ValuesAs<int8_t>(ptr)) = value.AsInt8();
345     return sizeof(int8_t);
346   }
347   if (type::Type::IsInt16(mode, num_bits)) {
348     *(ValuesAs<int16_t>(ptr)) = value.AsInt16();
349     return sizeof(int16_t);
350   }
351   if (type::Type::IsInt32(mode, num_bits)) {
352     *(ValuesAs<int32_t>(ptr)) = value.AsInt32();
353     return sizeof(int32_t);
354   }
355   if (type::Type::IsInt64(mode, num_bits)) {
356     *(ValuesAs<int64_t>(ptr)) = value.AsInt64();
357     return sizeof(int64_t);
358   }
359   if (type::Type::IsUint8(mode, num_bits)) {
360     *(ValuesAs<uint8_t>(ptr)) = value.AsUint8();
361     return sizeof(uint8_t);
362   }
363   if (type::Type::IsUint16(mode, num_bits)) {
364     *(ValuesAs<uint16_t>(ptr)) = value.AsUint16();
365     return sizeof(uint16_t);
366   }
367   if (type::Type::IsUint32(mode, num_bits)) {
368     *(ValuesAs<uint32_t>(ptr)) = value.AsUint32();
369     return sizeof(uint32_t);
370   }
371   if (type::Type::IsUint64(mode, num_bits)) {
372     *(ValuesAs<uint64_t>(ptr)) = value.AsUint64();
373     return sizeof(uint64_t);
374   }
375   if (type::Type::IsFloat16(mode, num_bits)) {
376     *(ValuesAs<uint16_t>(ptr)) = float16::FloatToHexFloat16(value.AsFloat());
377     return sizeof(uint16_t);
378   }
379   if (type::Type::IsFloat32(mode, num_bits)) {
380     *(ValuesAs<float>(ptr)) = value.AsFloat();
381     return sizeof(float);
382   }
383   if (type::Type::IsFloat64(mode, num_bits)) {
384     *(ValuesAs<double>(ptr)) = value.AsDouble();
385     return sizeof(double);
386   }
387 
388   // The float 10 and float 11 sizes are only used in PACKED formats.
389   assert(false && "Not reached");
390   return 0;
391 }
392 
SetSizeInElements(uint32_t element_count)393 void Buffer::SetSizeInElements(uint32_t element_count) {
394   element_count_ = element_count;
395   bytes_.resize(element_count * format_->SizeInBytes());
396 }
397 
SetSizeInBytes(uint32_t size_in_bytes)398 void Buffer::SetSizeInBytes(uint32_t size_in_bytes) {
399   assert(size_in_bytes % format_->SizeInBytes() == 0);
400   element_count_ = size_in_bytes / format_->SizeInBytes();
401   bytes_.resize(size_in_bytes);
402 }
403 
SetMaxSizeInBytes(uint32_t max_size_in_bytes)404 void Buffer::SetMaxSizeInBytes(uint32_t max_size_in_bytes) {
405   max_size_in_bytes_ = max_size_in_bytes;
406 }
407 
GetMaxSizeInBytes() const408 uint32_t Buffer::GetMaxSizeInBytes() const {
409   if (max_size_in_bytes_ != 0)
410     return max_size_in_bytes_;
411   else
412     return GetSizeInBytes();
413 }
414 
SetDataFromBuffer(const Buffer * src,uint32_t offset)415 Result Buffer::SetDataFromBuffer(const Buffer* src, uint32_t offset) {
416   if (bytes_.size() < offset + src->bytes_.size())
417     bytes_.resize(offset + src->bytes_.size());
418 
419   std::memcpy(bytes_.data() + offset, src->bytes_.data(), src->bytes_.size());
420   element_count_ =
421       static_cast<uint32_t>(bytes_.size()) / format_->SizeInBytes();
422   return {};
423 }
424 
425 }  // namespace amber
426