1 // Copyright 2018 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/buffer.h"
16
17 #include <algorithm>
18 #include <cassert>
19 #include <cmath>
20 #include <cstring>
21
22 #include "src/float16_helper.h"
23
24 namespace amber {
25 namespace {
26
// Views the bytes starting at |values| as an array of |T|. Nothing is copied:
// the returned pointer refers to the same address as |values|.
template <typename T>
T* ValuesAs(uint8_t* values) {
  void* untyped = values;
  return static_cast<T*>(untyped);
}
31
// Returns the value of type |T| stored at |buf1| minus the value of type |T|
// stored at |buf2|, as a double.
//
// The operands are loaded with memcpy so |buf1| and |buf2| do not need to be
// aligned for |T|. The smaller operand is always subtracted from the larger
// one and the sign is applied afterwards; this keeps unsigned operands from
// wrapping around (e.g. 1u - 2u yields -1.0 rather than 4294967295.0).
//
// Signed operands whose true difference does not fit in the promoted type
// still overflow, exactly as a plain subtraction would.
template <typename T>
double Sub(const uint8_t* buf1, const uint8_t* buf2) {
  T lhs;
  T rhs;
  std::memcpy(&lhs, buf1, sizeof(T));
  std::memcpy(&rhs, buf2, sizeof(T));

  if (lhs >= rhs)
    return static_cast<double>(lhs - rhs);

  // Also taken when either operand is a floating point NaN; the result is
  // then NaN, matching a direct subtraction.
  return -static_cast<double>(rhs - lhs);
}
37
CalculateDiff(const Format::Segment * seg,const uint8_t * buf1,const uint8_t * buf2)38 double CalculateDiff(const Format::Segment* seg,
39 const uint8_t* buf1,
40 const uint8_t* buf2) {
41 FormatMode mode = seg->GetFormatMode();
42 uint32_t num_bits = seg->GetNumBits();
43 if (type::Type::IsInt8(mode, num_bits))
44 return Sub<int8_t>(buf1, buf2);
45 if (type::Type::IsInt16(mode, num_bits))
46 return Sub<int16_t>(buf1, buf2);
47 if (type::Type::IsInt32(mode, num_bits))
48 return Sub<int32_t>(buf1, buf2);
49 if (type::Type::IsInt64(mode, num_bits))
50 return Sub<int64_t>(buf1, buf2);
51 if (type::Type::IsUint8(mode, num_bits))
52 return Sub<uint8_t>(buf1, buf2);
53 if (type::Type::IsUint16(mode, num_bits))
54 return Sub<uint16_t>(buf1, buf2);
55 if (type::Type::IsUint32(mode, num_bits))
56 return Sub<uint32_t>(buf1, buf2);
57 if (type::Type::IsUint64(mode, num_bits))
58 return Sub<uint64_t>(buf1, buf2);
59 if (type::Type::IsFloat16(mode, num_bits)) {
60 float val1 = float16::HexFloatToFloat(buf1, 16);
61 float val2 = float16::HexFloatToFloat(buf2, 16);
62 return static_cast<double>(val1 - val2);
63 }
64 if (type::Type::IsFloat32(mode, num_bits))
65 return Sub<float>(buf1, buf2);
66 if (type::Type::IsFloat64(mode, num_bits))
67 return Sub<double>(buf1, buf2);
68
69 assert(false && "NOTREACHED");
70 return 0.0;
71 }
72
73 } // namespace
74
// Out-of-line definition of the defaulted constructor.
75 Buffer::Buffer() = default;
76
// Out-of-line definition of the defaulted destructor.
77 Buffer::~Buffer() = default;
78
CopyTo(Buffer * buffer) const79 Result Buffer::CopyTo(Buffer* buffer) const {
80 if (buffer->width_ != width_)
81 return Result("Buffer::CopyBaseFields() buffers have a different width");
82 if (buffer->height_ != height_)
83 return Result("Buffer::CopyBaseFields() buffers have a different height");
84 if (buffer->element_count_ != element_count_)
85 return Result("Buffer::CopyBaseFields() buffers have a different size");
86 buffer->bytes_ = bytes_;
87 return {};
88 }
89
IsEqual(Buffer * buffer) const90 Result Buffer::IsEqual(Buffer* buffer) const {
91 auto result = CheckCompability(buffer);
92 if (!result.IsSuccess())
93 return result;
94
95 uint32_t num_different = 0;
96 uint32_t first_different_index = 0;
97 uint8_t first_different_left = 0;
98 uint8_t first_different_right = 0;
99 for (uint32_t i = 0; i < bytes_.size(); ++i) {
100 if (bytes_[i] != buffer->bytes_[i]) {
101 if (num_different == 0) {
102 first_different_index = i;
103 first_different_left = bytes_[i];
104 first_different_right = buffer->bytes_[i];
105 }
106 num_different++;
107 }
108 }
109
110 if (num_different) {
111 return Result{"Buffers have different values. " +
112 std::to_string(num_different) +
113 " values differed, first difference at byte " +
114 std::to_string(first_different_index) + " values " +
115 std::to_string(first_different_left) +
116 " != " + std::to_string(first_different_right)};
117 }
118
119 return {};
120 }
121
CalculateDiffs(const Buffer * buffer) const122 std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
123 std::vector<double> diffs;
124
125 auto* buf_1_ptr = GetValues<uint8_t>();
126 auto* buf_2_ptr = buffer->GetValues<uint8_t>();
127 const auto& segments = format_->GetSegments();
128 for (size_t i = 0; i < ElementCount(); ++i) {
129 for (const auto& seg : segments) {
130 if (seg.IsPadding()) {
131 buf_1_ptr += seg.PaddingBytes();
132 buf_2_ptr += seg.PaddingBytes();
133 continue;
134 }
135
136 diffs.push_back(CalculateDiff(&seg, buf_1_ptr, buf_2_ptr));
137
138 buf_1_ptr += seg.SizeInBytes();
139 buf_2_ptr += seg.SizeInBytes();
140 }
141 }
142
143 return diffs;
144 }
145
CheckCompability(Buffer * buffer) const146 Result Buffer::CheckCompability(Buffer* buffer) const {
147 if (!buffer->format_->Equal(format_))
148 return Result{"Buffers have a different format"};
149 if (buffer->element_count_ != element_count_)
150 return Result{"Buffers have a different size"};
151 if (buffer->width_ != width_)
152 return Result{"Buffers have a different width"};
153 if (buffer->height_ != height_)
154 return Result{"Buffers have a different height"};
155 if (buffer->ValueCount() != ValueCount())
156 return Result{"Buffers have a different number of values"};
157
158 return {};
159 }
160
CompareRMSE(Buffer * buffer,float tolerance) const161 Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
162 auto result = CheckCompability(buffer);
163 if (!result.IsSuccess())
164 return result;
165
166 auto diffs = CalculateDiffs(buffer);
167 double sum = 0.0;
168 for (const auto val : diffs)
169 sum += (val * val);
170
171 sum /= static_cast<double>(diffs.size());
172 double rmse = std::sqrt(sum);
173 if (rmse > static_cast<double>(tolerance)) {
174 return Result("Root Mean Square Error of " + std::to_string(rmse) +
175 " is greater than tolerance of " + std::to_string(tolerance));
176 }
177
178 return {};
179 }
180
GetHistogramForChannel(uint32_t channel,uint32_t num_bins) const181 std::vector<uint64_t> Buffer::GetHistogramForChannel(uint32_t channel,
182 uint32_t num_bins) const {
183 assert(num_bins == 256);
184 std::vector<uint64_t> bins(num_bins, 0);
185 auto* buf_ptr = GetValues<uint8_t>();
186 auto num_channels = format_->InputNeededPerElement();
187 uint32_t channel_id = 0;
188
189 for (size_t i = 0; i < ElementCount(); ++i) {
190 for (const auto& seg : format_->GetSegments()) {
191 if (seg.IsPadding()) {
192 buf_ptr += seg.PaddingBytes();
193 continue;
194 }
195 if (channel_id == channel) {
196 assert(type::Type::IsUint8(seg.GetFormatMode(), seg.GetNumBits()));
197 const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr);
198 bins[bin]++;
199 }
200 buf_ptr += seg.SizeInBytes();
201 channel_id = (channel_id + 1) % num_channels;
202 }
203 }
204
205 return bins;
206 }
207
CompareHistogramEMD(Buffer * buffer,float tolerance) const208 Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const {
209 auto result = CheckCompability(buffer);
210 if (!result.IsSuccess())
211 return result;
212
213 const int num_bins = 256;
214 auto num_channels = format_->InputNeededPerElement();
215 for (auto segment : format_->GetSegments()) {
216 if (!type::Type::IsUint8(segment.GetFormatMode(), segment.GetNumBits()) ||
217 num_channels != 4) {
218 return Result(
219 "EMD comparison only supports 8bit unorm format with four channels.");
220 }
221 }
222
223 std::vector<std::vector<uint64_t>> histogram1;
224 std::vector<std::vector<uint64_t>> histogram2;
225 for (uint32_t c = 0; c < num_channels; ++c) {
226 histogram1.push_back(GetHistogramForChannel(c, num_bins));
227 histogram2.push_back(buffer->GetHistogramForChannel(c, num_bins));
228 }
229
230 // Earth movers's distance: Calculate the minimal cost of moving "earth" to
231 // transform the first histogram into the second, where each bin of the
232 // histogram can be thought of as a column of units of earth. The cost is the
233 // amount of earth moved times the distance carried (the distance is the
234 // number of adjacent bins over which the earth is carried). Calculate this
235 // using the cumulative difference of the bins, which works as long as both
236 // histograms have the same amount of earth. Sum the absolute values of the
237 // cumulative difference to get the final cost of how much (and how far) the
238 // earth was moved.
239 double max_emd = 0;
240
241 for (uint32_t c = 0; c < num_channels; ++c) {
242 double diff_total = 0;
243 double diff_accum = 0;
244
245 for (size_t i = 0; i < num_bins; ++i) {
246 double hist_normalized_1 =
247 static_cast<double>(histogram1[c][i]) / element_count_;
248 double hist_normalized_2 =
249 static_cast<double>(histogram2[c][i]) / buffer->element_count_;
250 diff_accum += hist_normalized_1 - hist_normalized_2;
251 diff_total += fabs(diff_accum);
252 }
253 // Normalize to range 0..1
254 double emd = diff_total / num_bins;
255 max_emd = std::max(max_emd, emd);
256 }
257
258 if (max_emd > static_cast<double>(tolerance)) {
259 return Result("Histogram EMD value of " + std::to_string(max_emd) +
260 " is greater than tolerance of " + std::to_string(tolerance));
261 }
262
263 return {};
264 }
265
SetData(const std::vector<Value> & data)266 Result Buffer::SetData(const std::vector<Value>& data) {
267 return SetDataWithOffset(data, 0);
268 }
269
RecalculateMaxSizeInBytes(const std::vector<Value> & data,uint32_t offset)270 Result Buffer::RecalculateMaxSizeInBytes(const std::vector<Value>& data,
271 uint32_t offset) {
272 // Multiply by the input needed because the value count will use the needed
273 // input as the multiplier
274 uint32_t value_count =
275 ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
276 static_cast<uint32_t>(data.size());
277 uint32_t element_count = value_count;
278 if (!format_->IsPacked()) {
279 // This divides by the needed input values, not the values per element.
280 // The assumption being the values coming in are read from the input,
281 // where components are specified. The needed values maybe less then the
282 // values per element.
283 element_count = value_count / format_->InputNeededPerElement();
284 }
285 if (GetMaxSizeInBytes() < element_count * format_->SizeInBytes())
286 SetMaxSizeInBytes(element_count * format_->SizeInBytes());
287 return {};
288 }
289
SetDataWithOffset(const std::vector<Value> & data,uint32_t offset)290 Result Buffer::SetDataWithOffset(const std::vector<Value>& data,
291 uint32_t offset) {
292 // Multiply by the input needed because the value count will use the needed
293 // input as the multiplier
294 uint32_t value_count =
295 ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
296 static_cast<uint32_t>(data.size());
297
298 // The buffer should only be resized to become bigger. This means that if a
299 // command was run to set the buffer size we'll honour that size until a
300 // request happens to make the buffer bigger.
301 if (value_count > ValueCount())
302 SetValueCount(value_count);
303
304 // Even if the value count doesn't change, the buffer is still resized because
305 // this maybe the first time data is set into the buffer.
306 bytes_.resize(GetSizeInBytes());
307
308 // Set the new memory to zero to be on the safe side.
309 uint32_t new_space =
310 (static_cast<uint32_t>(data.size()) / format_->InputNeededPerElement()) *
311 format_->SizeInBytes();
312 assert(new_space + offset <= GetSizeInBytes());
313
314 if (new_space > 0)
315 memset(bytes_.data() + offset, 0, new_space);
316
317 if (data.size() > (ElementCount() * format_->InputNeededPerElement()))
318 return Result("Mismatched number of items in buffer");
319
320 uint8_t* ptr = bytes_.data() + offset;
321 const auto& segments = format_->GetSegments();
322 for (uint32_t i = 0; i < data.size();) {
323 for (const auto& seg : segments) {
324 if (seg.IsPadding()) {
325 ptr += seg.PaddingBytes();
326 continue;
327 }
328
329 Value v = data[i++];
330 ptr += WriteValueFromComponent(v, seg.GetFormatMode(), seg.GetNumBits(),
331 ptr);
332 if (i >= data.size())
333 break;
334 }
335 }
336 return {};
337 }
338
WriteValueFromComponent(const Value & value,FormatMode mode,uint32_t num_bits,uint8_t * ptr)339 uint32_t Buffer::WriteValueFromComponent(const Value& value,
340 FormatMode mode,
341 uint32_t num_bits,
342 uint8_t* ptr) {
343 if (type::Type::IsInt8(mode, num_bits)) {
344 *(ValuesAs<int8_t>(ptr)) = value.AsInt8();
345 return sizeof(int8_t);
346 }
347 if (type::Type::IsInt16(mode, num_bits)) {
348 *(ValuesAs<int16_t>(ptr)) = value.AsInt16();
349 return sizeof(int16_t);
350 }
351 if (type::Type::IsInt32(mode, num_bits)) {
352 *(ValuesAs<int32_t>(ptr)) = value.AsInt32();
353 return sizeof(int32_t);
354 }
355 if (type::Type::IsInt64(mode, num_bits)) {
356 *(ValuesAs<int64_t>(ptr)) = value.AsInt64();
357 return sizeof(int64_t);
358 }
359 if (type::Type::IsUint8(mode, num_bits)) {
360 *(ValuesAs<uint8_t>(ptr)) = value.AsUint8();
361 return sizeof(uint8_t);
362 }
363 if (type::Type::IsUint16(mode, num_bits)) {
364 *(ValuesAs<uint16_t>(ptr)) = value.AsUint16();
365 return sizeof(uint16_t);
366 }
367 if (type::Type::IsUint32(mode, num_bits)) {
368 *(ValuesAs<uint32_t>(ptr)) = value.AsUint32();
369 return sizeof(uint32_t);
370 }
371 if (type::Type::IsUint64(mode, num_bits)) {
372 *(ValuesAs<uint64_t>(ptr)) = value.AsUint64();
373 return sizeof(uint64_t);
374 }
375 if (type::Type::IsFloat16(mode, num_bits)) {
376 *(ValuesAs<uint16_t>(ptr)) = float16::FloatToHexFloat16(value.AsFloat());
377 return sizeof(uint16_t);
378 }
379 if (type::Type::IsFloat32(mode, num_bits)) {
380 *(ValuesAs<float>(ptr)) = value.AsFloat();
381 return sizeof(float);
382 }
383 if (type::Type::IsFloat64(mode, num_bits)) {
384 *(ValuesAs<double>(ptr)) = value.AsDouble();
385 return sizeof(double);
386 }
387
388 // The float 10 and float 11 sizes are only used in PACKED formats.
389 assert(false && "Not reached");
390 return 0;
391 }
392
SetSizeInElements(uint32_t element_count)393 void Buffer::SetSizeInElements(uint32_t element_count) {
394 element_count_ = element_count;
395 bytes_.resize(element_count * format_->SizeInBytes());
396 }
397
SetSizeInBytes(uint32_t size_in_bytes)398 void Buffer::SetSizeInBytes(uint32_t size_in_bytes) {
399 assert(size_in_bytes % format_->SizeInBytes() == 0);
400 element_count_ = size_in_bytes / format_->SizeInBytes();
401 bytes_.resize(size_in_bytes);
402 }
403
SetMaxSizeInBytes(uint32_t max_size_in_bytes)404 void Buffer::SetMaxSizeInBytes(uint32_t max_size_in_bytes) {
405 max_size_in_bytes_ = max_size_in_bytes;
406 }
407
GetMaxSizeInBytes() const408 uint32_t Buffer::GetMaxSizeInBytes() const {
409 if (max_size_in_bytes_ != 0)
410 return max_size_in_bytes_;
411 else
412 return GetSizeInBytes();
413 }
414
SetDataFromBuffer(const Buffer * src,uint32_t offset)415 Result Buffer::SetDataFromBuffer(const Buffer* src, uint32_t offset) {
416 if (bytes_.size() < offset + src->bytes_.size())
417 bytes_.resize(offset + src->bytes_.size());
418
419 std::memcpy(bytes_.data() + offset, src->bytes_.data(), src->bytes_.size());
420 element_count_ =
421 static_cast<uint32_t>(bytes_.size()) / format_->SizeInBytes();
422 return {};
423 }
424
425 } // namespace amber
426