1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2018 The Khronos Group Inc.
6 * Copyright (c) 2015 Samsung Electronics Co., Ltd.
7 * Copyright (c) 2016 The Android Open Source Project
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 * http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 *//*!
22 * \file
23 * \brief Precision and range tests for builtins and types.
24 *
25 *//*--------------------------------------------------------------------*/
26
27 #include "vktShaderBuiltinPrecisionTests.hpp"
28 #include "vktShaderExecutor.hpp"
29 #include "amber/vktAmberTestCase.hpp"
30
31 #include "deMath.h"
32 #include "deMemory.h"
33 #include "deFloat16.h"
34 #include "deDefs.hpp"
35 #include "deRandom.hpp"
36 #include "deSTLUtil.hpp"
37 #include "deStringUtil.hpp"
38 #include "deUniquePtr.hpp"
39 #include "deSharedPtr.hpp"
40 #include "deArrayUtil.hpp"
41
42 #include "tcuCommandLine.hpp"
43 #include "tcuFloatFormat.hpp"
44 #include "tcuInterval.hpp"
45 #include "tcuTestLog.hpp"
46 #include "tcuVector.hpp"
47 #include "tcuMatrix.hpp"
48 #include "tcuResultCollector.hpp"
49 #include "tcuMaybe.hpp"
50
51 #include "gluContextInfo.hpp"
52 #include "gluVarType.hpp"
53 #include "gluRenderContext.hpp"
54 #include "glwDefs.hpp"
55
56 #include <cmath>
57 #include <string>
58 #include <sstream>
59 #include <iostream>
60 #include <map>
61 #include <utility>
62 #include <limits>
63
64 // Uncomment this to get evaluation trace dumps to std::cerr
65 // #define GLS_ENABLE_TRACE
66
67 // set this to true to dump even passing results
68 #define GLS_LOG_ALL_RESULTS false
69
// Raw 16-bit patterns of frequently used half-precision (float16) constants.
#define FLOAT16_1_0 0x3C00 // bit pattern of 1.0 as a 16-bit float
#define FLOAT16_180_0 0x59A0 // bit pattern of 180.0 as a 16-bit float
#define FLOAT16_2_0 0x4000 // bit pattern of 2.0 as a 16-bit float
#define FLOAT16_3_0 0x4200 // bit pattern of 3.0 as a 16-bit float
#define FLOAT16_0_5 0x3800 // bit pattern of 0.5 as a 16-bit float
#define FLOAT16_0_0 0x0000 // bit pattern of 0.0 as a 16-bit float
76
77
78 using tcu::Vector;
79 typedef Vector<deFloat16, 1> Vec1_16Bit;
80 typedef Vector<deFloat16, 2> Vec2_16Bit;
81 typedef Vector<deFloat16, 3> Vec3_16Bit;
82 typedef Vector<deFloat16, 4> Vec4_16Bit;
83
84 typedef Vector<double, 1> Vec1_64Bit;
85 typedef Vector<double, 2> Vec2_64Bit;
86 typedef Vector<double, 3> Vec3_64Bit;
87 typedef Vector<double, 4> Vec4_64Bit;
88
enum
{
    // Computing reference intervals can take a non-trivial amount of time, especially on
    // platforms where toggling the floating-point rounding mode is slow (e.g. emulated ARM
    // on x86). As a workaround, the watchdog is kept happy by touching it periodically
    // during reference interval computation.
    // NOTE(review): presumably the number of values processed between two touches --
    // confirm at the use site.
    TOUCH_WATCHDOG_VALUE_FREQUENCY = 512
};
97
98 namespace vkt
99 {
100 namespace shaderexecutor
101 {
102
103 using std::string;
104 using std::map;
105 using std::ostream;
106 using std::ostringstream;
107 using std::pair;
108 using std::vector;
109 using std::set;
110
111 using de::MovePtr;
112 using de::Random;
113 using de::SharedPtr;
114 using de::UniquePtr;
115 using tcu::Interval;
116 using tcu::FloatFormat;
117 using tcu::MessageBuilder;
118 using tcu::TestLog;
119 using tcu::Vector;
120 using tcu::Matrix;
121 using glu::Precision;
122 using glu::VarType;
123 using glu::DataType;
124 using glu::ShaderType;
125
//! Bit flags naming the optional device features a precision test may require.
//! Flags are OR-ed together into a PrecisionTestFeatures mask and validated by
//! areFeaturesSupported(). Note that bit 0 is not assigned to any flag.
enum PrecisionTestFeatureBits
{
    PRECISION_TEST_FEATURES_NONE = 0u, //!< No optional feature required.
    PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS = (1u << 1), //!< Checked against storageBuffer16BitAccess.
    PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS = (1u << 2), //!< Checked against uniformAndStorageBuffer16BitAccess.
    PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT = (1u << 3), //!< Checked against storagePushConstant16.
    PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT = (1u << 4), //!< Checked against storageInputOutput16.
    PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT = (1u << 5), //!< Checked against shaderFloat16.
    PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT = (1u << 6), //!< Checked against shaderFloat64.
};
//! Mask type holding any combination of PrecisionTestFeatureBits.
typedef deUint32 PrecisionTestFeatures;
137
138
areFeaturesSupported(const Context & context,deUint32 toCheck)139 void areFeaturesSupported (const Context& context, deUint32 toCheck)
140 {
141 if (toCheck == PRECISION_TEST_FEATURES_NONE) return;
142
143 const vk::VkPhysicalDevice16BitStorageFeatures& extensionFeatures = context.get16BitStorageFeatures();
144
145 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS) != 0 && extensionFeatures.storageBuffer16BitAccess == VK_FALSE)
146 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
147
148 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS) != 0 && extensionFeatures.uniformAndStorageBuffer16BitAccess == VK_FALSE)
149 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
150
151 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT) != 0 && extensionFeatures.storagePushConstant16 == VK_FALSE)
152 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
153
154 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT) != 0 && extensionFeatures.storageInputOutput16 == VK_FALSE)
155 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
156
157 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT) != 0 && context.getShaderFloat16Int8Features().shaderFloat16 == VK_FALSE)
158 TCU_THROW(NotSupportedError, "Requested 16-bit floats (halfs) are not supported in shader code");
159
160 if ((toCheck & PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT) != 0 && context.getDeviceFeatures().shaderFloat64 == VK_FALSE)
161 TCU_THROW(NotSupportedError, "Requested 64-bit floats are not supported in shader code");
162 }
163
164 /*--------------------------------------------------------------------*//*!
165 * \brief Generic singleton creator.
166 *
167 * instance<T>() returns a reference to a unique default-constructed instance
168 * of T. This is mainly used for our GLSL function implementations: each
169 * function is implemented by an object, and each of the objects has a
170 * distinct class. It would be extremely toilsome to maintain a separate
171 * context object that contained individual instances of the function classes,
172 * so we have to resort to global singleton instances.
173 *
174 *//*--------------------------------------------------------------------*/
//! Returns a reference to the single, value-initialized, immutable instance of T.
//! The object is a function-local static, so there is exactly one per type T and
//! it is created the first time instance<T>() is called.
template <typename T>
const T& instance (void)
{
    static const T s_theOnlyOne = T();

    return s_theOnlyOne;
}
181
182 /*--------------------------------------------------------------------*//*!
183 * \brief Empty placeholder type for unused template parameters.
184 *
185 * In the precision tests we are dealing with functions of different arities.
186 * To minimize code duplication, we only define templates with the maximum
187 * number of arguments, currently four. If a function's arity is less than the
 * maximum, Void is used as the type for unused arguments.
189 *
190 * Although Voids are not used at run-time, they still must be compilable, so
191 * they must support all operations that other types do.
192 *
193 *//*--------------------------------------------------------------------*/
194 struct Void
195 {
196 typedef Void Element;
197 enum
198 {
199 SIZE = 0,
200 };
201
202 template <typename T>
Voidvkt::shaderexecutor::Void203 explicit Void (const T&) {}
Voidvkt::shaderexecutor::Void204 Void (void) {}
operator doublevkt::shaderexecutor::Void205 operator double (void) const { return TCU_NAN; }
206
207 // These are used to make Voids usable as containers in container-generic code.
operator []vkt::shaderexecutor::Void208 Void& operator[] (int) { return *this; }
operator []vkt::shaderexecutor::Void209 const Void& operator[] (int) const { return *this; }
210 };
211
operator <<(ostream & os,Void)212 ostream& operator<< (ostream& os, Void) { return os << "()"; }
213
214 //! Returns true for all other types except Void
isTypeValid(void)215 template <typename T> bool isTypeValid (void) { return true; }
isTypeValid(void)216 template <> bool isTypeValid<Void> (void) { return false; }
217
isInteger(void)218 template <typename T> bool isInteger (void) { return false; }
isInteger(void)219 template <> bool isInteger<int> (void) { return true; }
isInteger(void)220 template <> bool isInteger<tcu::IVec2> (void) { return true; }
isInteger(void)221 template <> bool isInteger<tcu::IVec3> (void) { return true; }
isInteger(void)222 template <> bool isInteger<tcu::IVec4> (void) { return true; }
223
224 //! Utility function for getting the name of a data type.
225 //! This is used in vector and matrix constructors.
226 template <typename T>
dataTypeNameOf(void)227 const char* dataTypeNameOf (void)
228 {
229 return glu::getDataTypeName(glu::dataTypeOf<T>());
230 }
231
232 template <>
dataTypeNameOf(void)233 const char* dataTypeNameOf<Void> (void)
234 {
235 DE_FATAL("Impossible");
236 return DE_NULL;
237 }
238
239 template <typename T>
getVarTypeOf(Precision prec=glu::PRECISION_LAST)240 VarType getVarTypeOf (Precision prec = glu::PRECISION_LAST)
241 {
242 return glu::varTypeOf<T>(prec);
243 }
244
245 //! A hack to get Void support for VarType.
246 template <>
getVarTypeOf(Precision)247 VarType getVarTypeOf<Void> (Precision)
248 {
249 DE_FATAL("Impossible");
250 return VarType();
251 }
252
253 /*--------------------------------------------------------------------*//*!
254 * \brief Type traits for generalized interval types.
255 *
256 * We are trying to compute sets of acceptable values not only for
257 * float-valued expressions but also for compound values: vectors and
258 * matrices. We approximate a set of vectors as a vector of intervals and
259 * likewise for matrices.
260 *
261 * We now need generalized operations for each type and its interval
262 * approximation. These are given in the type Traits<T>.
263 *
264 * The type Traits<T>::IVal is the approximation of T: it is `Interval` for
265 * scalar types, and a vector or matrix of intervals for container types.
266 *
267 * To allow template inference to take place, there are function wrappers for
268 * the actual operations in Traits<T>. Hence we can just use:
269 *
270 * makeIVal(someFloat)
271 *
272 * instead of:
273 *
274 * Traits<float>::doMakeIVal(value)
275 *
276 *//*--------------------------------------------------------------------*/
277
278 template <typename T> struct Traits;
279
280 //! Create container from elementwise singleton values.
281 template <typename T>
makeIVal(const T & value)282 typename Traits<T>::IVal makeIVal (const T& value)
283 {
284 return Traits<T>::doMakeIVal(value);
285 }
286
287 //! Elementwise union of intervals.
288 template <typename T>
unionIVal(const typename Traits<T>::IVal & a,const typename Traits<T>::IVal & b)289 typename Traits<T>::IVal unionIVal (const typename Traits<T>::IVal& a,
290 const typename Traits<T>::IVal& b)
291 {
292 return Traits<T>::doUnion(a, b);
293 }
294
295 //! Returns true iff every element of `ival` contains the corresponding element of `value`.
296 template <typename T, typename U = Void>
contains(const typename Traits<T>::IVal & ival,const T & value,bool is16Bit=false,const tcu::Maybe<U> & modularDivisor=tcu::Nothing)297 bool contains (const typename Traits<T>::IVal& ival, const T& value, bool is16Bit = false, const tcu::Maybe<U>& modularDivisor = tcu::Nothing)
298 {
299 return Traits<T>::doContains(ival, value, is16Bit, modularDivisor);
300 }
301
302 //! Print out an interval with the precision of `fmt`.
303 template <typename T>
printIVal(const FloatFormat & fmt,const typename Traits<T>::IVal & ival,ostream & os)304 void printIVal (const FloatFormat& fmt, const typename Traits<T>::IVal& ival, ostream& os)
305 {
306 Traits<T>::doPrintIVal(fmt, ival, os);
307 }
308
309 template <typename T>
intervalToString(const FloatFormat & fmt,const typename Traits<T>::IVal & ival)310 string intervalToString (const FloatFormat& fmt, const typename Traits<T>::IVal& ival)
311 {
312 ostringstream oss;
313 printIVal<T>(fmt, ival, oss);
314 return oss.str();
315 }
316
317 //! Print out a value with the precision of `fmt`.
318 template <typename T>
printValue16(const FloatFormat & fmt,const T & value,ostream & os)319 void printValue16 (const FloatFormat& fmt, const T& value, ostream& os)
320 {
321 Traits<T>::doPrintValue16(fmt, value, os);
322 }
323
324 template <typename T>
value16ToString(const FloatFormat & fmt,const T & val)325 string value16ToString(const FloatFormat& fmt, const T& val)
326 {
327 ostringstream oss;
328 printValue16(fmt, val, oss);
329 return oss.str();
330 }
331
getComparisonOperation(const int ndx)332 const std::string getComparisonOperation(const int ndx)
333 {
334 const int operationCount = 10;
335 DE_ASSERT(de::inBounds(ndx, 0, operationCount));
336 const std::string operations[operationCount] =
337 {
338 "OpFOrdEqual\t\t\t",
339 "OpFOrdGreaterThan\t",
340 "OpFOrdLessThan\t\t",
341 "OpFOrdGreaterThanEqual",
342 "OpFOrdLessThanEqual\t",
343 "OpFUnordEqual\t\t",
344 "OpFUnordGreaterThan\t",
345 "OpFUnordLessThan\t",
346 "OpFUnordGreaterThanEqual",
347 "OpFUnordLessThanEqual"
348 };
349 return operations[ndx];
350 }
351
352 template <typename T>
comparisonMessage(const T & val)353 string comparisonMessage(const T& val)
354 {
355 DE_UNREF(val);
356 return "";
357 }
358
359 template <>
comparisonMessage(const int & val)360 string comparisonMessage(const int& val)
361 {
362 ostringstream oss;
363
364 int flags = val;
365 for(int ndx = 0; ndx < 10; ++ndx)
366 {
367 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags & 1) == 1 ? "TRUE" : "FALSE") << "\n";
368 flags = flags >> 1;
369 }
370 return oss.str();
371 }
372
373 template <>
comparisonMessage(const tcu::IVec2 & val)374 string comparisonMessage(const tcu::IVec2& val)
375 {
376 ostringstream oss;
377 tcu::IVec2 flags = val;
378 for (int ndx = 0; ndx < 10; ++ndx)
379 {
380 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t" << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
381 flags.x() = flags.x() >> 1;
382 flags.y() = flags.y() >> 1;
383 }
384 return oss.str();
385 }
386
387 template <>
comparisonMessage(const tcu::IVec3 & val)388 string comparisonMessage(const tcu::IVec3& val)
389 {
390 ostringstream oss;
391 tcu::IVec3 flags = val;
392 for (int ndx = 0; ndx < 10; ++ndx)
393 {
394 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
395 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
396 << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
397 flags.x() = flags.x() >> 1;
398 flags.y() = flags.y() >> 1;
399 flags.z() = flags.z() >> 1;
400 }
401 return oss.str();
402 }
403
404 template <>
comparisonMessage(const tcu::IVec4 & val)405 string comparisonMessage(const tcu::IVec4& val)
406 {
407 ostringstream oss;
408 tcu::IVec4 flags = val;
409 for (int ndx = 0; ndx < 10; ++ndx)
410 {
411 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
412 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
413 << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
414 << ((flags.w() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
415 flags.x() = flags.x() >> 1;
416 flags.y() = flags.y() >> 1;
417 flags.z() = flags.z() >> 1;
418 flags.w() = flags.z() >> 1;
419 }
420 return oss.str();
421 }
422 //! Print out a value with the precision of `fmt`.
423 template <typename T>
printValue32(const FloatFormat & fmt,const T & value,ostream & os)424 void printValue32 (const FloatFormat& fmt, const T& value, ostream& os)
425 {
426 Traits<T>::doPrintValue32(fmt, value, os);
427 }
428
429 template <typename T>
value32ToString(const FloatFormat & fmt,const T & val)430 string value32ToString (const FloatFormat& fmt, const T& val)
431 {
432 ostringstream oss;
433 printValue32(fmt, val, oss);
434 return oss.str();
435 }
436
437 template <typename T>
printValue64(const FloatFormat & fmt,const T & value,ostream & os)438 void printValue64 (const FloatFormat& fmt, const T& value, ostream& os)
439 {
440 Traits<T>::doPrintValue64(fmt, value, os);
441 }
442
443 template <typename T>
value64ToString(const FloatFormat & fmt,const T & val)444 string value64ToString (const FloatFormat& fmt, const T& val)
445 {
446 ostringstream oss;
447 printValue64(fmt, val, oss);
448 return oss.str();
449 }
450
451 //! Approximate `value` elementwise to the float precision defined in `fmt`.
452 //! The resulting interval might not be a singleton if rounding in both
453 //! directions is allowed.
454 template <typename T>
round(const FloatFormat & fmt,const T & value)455 typename Traits<T>::IVal round (const FloatFormat& fmt, const T& value)
456 {
457 return Traits<T>::doRound(fmt, value);
458 }
459
460 template <typename T>
convert(const FloatFormat & fmt,const typename Traits<T>::IVal & value)461 typename Traits<T>::IVal convert (const FloatFormat& fmt,
462 const typename Traits<T>::IVal& value)
463 {
464 return Traits<T>::doConvert(fmt, value);
465 }
466
467 // Matching input and output types. We may be in a modulo case and modularDivisor may have an actual value.
468 template <typename T>
intervalContains(const Interval & interval,T value,const tcu::Maybe<T> & modularDivisor)469 bool intervalContains (const Interval& interval, T value, const tcu::Maybe<T>& modularDivisor)
470 {
471 bool contained = interval.contains(value);
472
473 if (!contained && modularDivisor)
474 {
475 const T divisor = modularDivisor.get();
476
477 // In a modulo operation, if the calculated answer contains the divisor, allow exactly 0.0 as a replacement. Alternatively,
478 // if the calculated answer contains 0.0, allow exactly the divisor as a replacement.
479 if (interval.contains(static_cast<double>(divisor)))
480 contained |= (value == 0.0);
481 if (interval.contains(0.0))
482 contained |= (value == divisor);
483 }
484 return contained;
485 }
486
487 // When the input and output types do not match, we are not in a real modulo operation. Do not take the divisor into account. This
488 // version is provided for syntactical compatibility only.
489 template <typename T, typename U>
intervalContains(const Interval & interval,T value,const tcu::Maybe<U> & modularDivisor)490 bool intervalContains (const Interval& interval, T value, const tcu::Maybe<U>& modularDivisor)
491 {
492 DE_UNREF(modularDivisor); // For release builds.
493 DE_ASSERT(!modularDivisor);
494 return interval.contains(value);
495 }
496
497 //! Common traits for scalar types.
498 template <typename T>
499 struct ScalarTraits
500 {
501 typedef Interval IVal;
502
doMakeIValvkt::shaderexecutor::ScalarTraits503 static Interval doMakeIVal (const T& value)
504 {
505 // Thankfully all scalar types have a well-defined conversion to `double`,
506 // hence Interval can represent their ranges without problems.
507 return Interval(double(value));
508 }
509
doUnionvkt::shaderexecutor::ScalarTraits510 static Interval doUnion (const Interval& a, const Interval& b)
511 {
512 return a | b;
513 }
514
doContainsvkt::shaderexecutor::ScalarTraits515 static bool doContains (const Interval& a, T value)
516 {
517 return a.contains(double(value));
518 }
519
doConvertvkt::shaderexecutor::ScalarTraits520 static Interval doConvert (const FloatFormat& fmt, const IVal& ival)
521 {
522 return fmt.convert(ival);
523 }
524
doConvertvkt::shaderexecutor::ScalarTraits525 static Interval doConvert (const FloatFormat& fmt, const IVal& ival, bool is16Bit)
526 {
527 DE_UNREF(is16Bit);
528 return fmt.convert(ival);
529 }
530
doRoundvkt::shaderexecutor::ScalarTraits531 static Interval doRound (const FloatFormat& fmt, T value)
532 {
533 return fmt.roundOut(double(value), false);
534 }
535 };
536
537 template <>
538 struct ScalarTraits<deUint16>
539 {
540 typedef Interval IVal;
541
doMakeIValvkt::shaderexecutor::ScalarTraits542 static Interval doMakeIVal (const deUint16& value)
543 {
544 // Thankfully all scalar types have a well-defined conversion to `double`,
545 // hence Interval can represent their ranges without problems.
546 return Interval(double(deFloat16To32(value)));
547 }
548
doUnionvkt::shaderexecutor::ScalarTraits549 static Interval doUnion (const Interval& a, const Interval& b)
550 {
551 return a | b;
552 }
553
doConvertvkt::shaderexecutor::ScalarTraits554 static Interval doConvert (const FloatFormat& fmt, const IVal& ival)
555 {
556 return fmt.convert(ival);
557 }
558
doRoundvkt::shaderexecutor::ScalarTraits559 static Interval doRound (const FloatFormat& fmt, deUint16 value)
560 {
561 return fmt.roundOut(double(deFloat16To32(value)), false);
562 }
563 };
564
565 template<>
566 struct Traits<float> : ScalarTraits<float>
567 {
doPrintIValvkt::shaderexecutor::Traits568 static void doPrintIVal (const FloatFormat& fmt,
569 const Interval& ival,
570 ostream& os)
571 {
572 os << fmt.intervalToHex(ival);
573 }
574
doPrintValue16vkt::shaderexecutor::Traits575 static void doPrintValue16 (const FloatFormat& fmt,
576 const float& value,
577 ostream& os)
578 {
579 const deUint32 iRep = reinterpret_cast<const deUint32 & >(value);
580 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
581 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
582 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
583 }
584
doPrintValue32vkt::shaderexecutor::Traits585 static void doPrintValue32 (const FloatFormat& fmt,
586 const float& value,
587 ostream& os)
588 {
589 os << fmt.floatToHex(value);
590 }
591
doPrintValue64vkt::shaderexecutor::Traits592 static void doPrintValue64 (const FloatFormat& fmt,
593 const float& value,
594 ostream& os)
595 {
596 os << fmt.floatToHex(value);
597 }
598
599 template <typename U>
doContainsvkt::shaderexecutor::Traits600 static bool doContains (const Interval& a, const float& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
601 {
602 if(is16Bit)
603 {
604 // Note: for deFloat16s packed in 32 bits, the original divisor is provided as a float to the shader in the input
605 // buffer, so U is also float here and we call the right interlvalContains() version.
606 const deUint32 iRep = reinterpret_cast<const deUint32&>(value);
607 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
608 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
609 return intervalContains(a, res0, modularDivisor) && (res1 == -1.0);
610 }
611 return intervalContains(a, value, modularDivisor);
612 }
613 };
614
615 template<>
616 struct Traits<double> : ScalarTraits<double>
617 {
doPrintIValvkt::shaderexecutor::Traits618 static void doPrintIVal (const FloatFormat& fmt,
619 const Interval& ival,
620 ostream& os)
621 {
622 os << fmt.intervalToHex(ival);
623 }
624
doPrintValue16vkt::shaderexecutor::Traits625 static void doPrintValue16 (const FloatFormat& fmt,
626 const double& value,
627 ostream& os)
628 {
629 const deUint64 iRep = reinterpret_cast<const deUint64&>(value);
630 double byte0 = deFloat16To64((deFloat16)((iRep ) & 0xffff));
631 double byte1 = deFloat16To64((deFloat16)((iRep >> 16) & 0xffff));
632 double byte2 = deFloat16To64((deFloat16)((iRep >> 32) & 0xffff));
633 double byte3 = deFloat16To64((deFloat16)((iRep >> 48) & 0xffff));
634 os << fmt.floatToHex(byte0) << " " << fmt.floatToHex(byte1) << " " << fmt.floatToHex(byte2) << " " << fmt.floatToHex(byte3);
635 }
636
doPrintValue32vkt::shaderexecutor::Traits637 static void doPrintValue32 (const FloatFormat& fmt,
638 const double& value,
639 ostream& os)
640 {
641 const deUint64 iRep = reinterpret_cast<const deUint64&>(value);
642 double res0 = static_cast<double>((float)((iRep ) & 0xffffffff));
643 double res1 = static_cast<double>((float)((iRep >> 32) & 0xffffffff));
644 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
645 }
646
doPrintValue64vkt::shaderexecutor::Traits647 static void doPrintValue64 (const FloatFormat& fmt,
648 const double& value,
649 ostream& os)
650 {
651 os << fmt.floatToHex(value);
652 }
653
654 template <class U>
doContainsvkt::shaderexecutor::Traits655 static bool doContains (const Interval& a, const double& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
656 {
657 DE_UNREF(is16Bit);
658 DE_ASSERT(!is16Bit);
659 return intervalContains(a, value, modularDivisor);
660 }
661 };
662
663 template<>
664 struct Traits<deFloat16> : ScalarTraits<deFloat16>
665 {
doPrintIValvkt::shaderexecutor::Traits666 static void doPrintIVal (const FloatFormat& fmt,
667 const Interval& ival,
668 ostream& os)
669 {
670 os << fmt.intervalToHex(ival);
671 }
672
doPrintValue16vkt::shaderexecutor::Traits673 static void doPrintValue16 (const FloatFormat& fmt,
674 const deFloat16& value,
675 ostream& os)
676 {
677 const float res0 = deFloat16To32(value);
678 os << fmt.floatToHex(static_cast<double>(res0));
679 }
doPrintValue32vkt::shaderexecutor::Traits680 static void doPrintValue32 (const FloatFormat& fmt,
681 const deFloat16& value,
682 ostream& os)
683 {
684 const float res0 = deFloat16To32(value);
685 os << fmt.floatToHex(static_cast<double>(res0));
686 }
687
doPrintValue64vkt::shaderexecutor::Traits688 static void doPrintValue64 (const FloatFormat& fmt,
689 const deFloat16& value,
690 ostream& os)
691 {
692 const double res0 = deFloat16To64(value);
693 os << fmt.floatToHex(res0);
694 }
695
696 // When the value and divisor are both deFloat16, convert both to float to call the right intervalContains version.
doContainsvkt::shaderexecutor::Traits697 static bool doContains (const Interval& a, const deFloat16& value, bool is16Bit, const tcu::Maybe<deFloat16>& modularDivisor)
698 {
699 DE_UNREF(is16Bit);
700 float res0 = deFloat16To32(value);
701 const tcu::Maybe<float> convertedDivisor = (modularDivisor ? tcu::just(deFloat16To32(modularDivisor.get())) : tcu::Nothing);
702 return intervalContains(a, res0, convertedDivisor);
703 }
704
705 // If the types don't match we should not be in a modulo operation, so no conversion should take place.
706 template <class U>
doContainsvkt::shaderexecutor::Traits707 static bool doContains (const Interval& a, const deFloat16& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
708 {
709 DE_UNREF(is16Bit);
710 float res0 = deFloat16To32(value);
711 return intervalContains(a, res0, modularDivisor);
712 }
713 };
714
715 template<>
716 struct Traits<bool> : ScalarTraits<bool>
717 {
doPrintValue16vkt::shaderexecutor::Traits718 static void doPrintValue16 (const FloatFormat&,
719 const float& value,
720 ostream& os)
721 {
722 os << (value != 0.0f ? "true" : "false");
723 }
724
doPrintValue32vkt::shaderexecutor::Traits725 static void doPrintValue32 (const FloatFormat&,
726 const float& value,
727 ostream& os)
728 {
729 os << (value != 0.0f ? "true" : "false");
730 }
731
doPrintValue64vkt::shaderexecutor::Traits732 static void doPrintValue64 (const FloatFormat&,
733 const float& value,
734 ostream& os)
735 {
736 os << (value != 0.0f ? "true" : "false");
737 }
738
doPrintIValvkt::shaderexecutor::Traits739 static void doPrintIVal (const FloatFormat&,
740 const Interval& ival,
741 ostream& os)
742 {
743 os << "{";
744 if (ival.contains(false))
745 os << "false";
746 if (ival.contains(false) && ival.contains(true))
747 os << ", ";
748 if (ival.contains(true))
749 os << "true";
750 os << "}";
751 }
752 };
753
754 template<>
755 struct Traits<int> : ScalarTraits<int>
756 {
doPrintValue16vkt::shaderexecutor::Traits757 static void doPrintValue16 (const FloatFormat&,
758 const int& value,
759 ostream& os)
760 {
761 int res0 = value & 0xFFFF;
762 int res1 = value >> 16;
763 os << res0 << " " << res1;
764 }
765
doPrintValue32vkt::shaderexecutor::Traits766 static void doPrintValue32 (const FloatFormat&,
767 const int& value,
768 ostream& os)
769 {
770 os << value;
771 }
772
doPrintValue64vkt::shaderexecutor::Traits773 static void doPrintValue64 (const FloatFormat&,
774 const int& value,
775 ostream& os)
776 {
777 os << value;
778 }
779
doPrintIValvkt::shaderexecutor::Traits780 static void doPrintIVal (const FloatFormat&,
781 const Interval& ival,
782 ostream& os)
783 {
784 os << "[" << int(ival.lo()) << ", " << int(ival.hi()) << "]";
785 }
786
787 template <typename U>
doContainsvkt::shaderexecutor::Traits788 static bool doContains (const Interval& a, const int& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
789 {
790 DE_UNREF(is16Bit);
791 return intervalContains(a, value, modularDivisor);
792 }
793 };
794
795 //! Common traits for containers, i.e. vectors and matrices.
796 //! T is the container type itself, I is the same type with interval elements.
797 template <typename T, typename I>
798 struct ContainerTraits
799 {
800 typedef typename T::Element Element;
801 typedef I IVal;
802
doMakeIValvkt::shaderexecutor::ContainerTraits803 static IVal doMakeIVal (const T& value)
804 {
805 IVal ret;
806
807 for (int ndx = 0; ndx < T::SIZE; ++ndx)
808 ret[ndx] = makeIVal(value[ndx]);
809
810 return ret;
811 }
812
doUnionvkt::shaderexecutor::ContainerTraits813 static IVal doUnion (const IVal& a, const IVal& b)
814 {
815 IVal ret;
816
817 for (int ndx = 0; ndx < T::SIZE; ++ndx)
818 ret[ndx] = unionIVal<Element>(a[ndx], b[ndx]);
819
820 return ret;
821 }
822
823 // When the input and output types match, we may be in a modulo operation. If the divisor is provided, use each of its
824 // components to determine if the obtained result is fine.
doContainsvkt::shaderexecutor::ContainerTraits825 static bool doContains (const IVal& ival, const T& value, bool is16Bit, const tcu::Maybe<T>& modularDivisor)
826 {
827 using DivisorElement = typename T::Element;
828
829 for (int ndx = 0; ndx < T::SIZE; ++ndx)
830 {
831 const tcu::Maybe<DivisorElement> divisorElement = (modularDivisor ? tcu::just((*modularDivisor)[ndx]) : tcu::Nothing);
832 if (!contains(ival[ndx], value[ndx], is16Bit, divisorElement))
833 return false;
834 }
835
836 return true;
837 }
838
839 // When the input and output types do not match we should not be in a modulo operation. This version is provided for syntactical
840 // compatibility.
841 template <typename U>
doContainsvkt::shaderexecutor::ContainerTraits842 static bool doContains (const IVal& ival, const T& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
843 {
844 for (int ndx = 0; ndx < T::SIZE; ++ndx)
845 {
846 if (!contains(ival[ndx], value[ndx], is16Bit, modularDivisor))
847 return false;
848 }
849
850 return true;
851 }
852
doPrintIValvkt::shaderexecutor::ContainerTraits853 static void doPrintIVal (const FloatFormat& fmt, const IVal ival, ostream& os)
854 {
855 os << "(";
856
857 for (int ndx = 0; ndx < T::SIZE; ++ndx)
858 {
859 if (ndx > 0)
860 os << ", ";
861
862 printIVal<Element>(fmt, ival[ndx], os);
863 }
864
865 os << ")";
866 }
867
doPrintValue16vkt::shaderexecutor::ContainerTraits868 static void doPrintValue16 (const FloatFormat& fmt, const T& value, ostream& os)
869 {
870 os << dataTypeNameOf<T>() << "(";
871
872 for (int ndx = 0; ndx < T::SIZE; ++ndx)
873 {
874 if (ndx > 0)
875 os << ", ";
876
877 printValue16<Element>(fmt, value[ndx], os);
878 }
879
880 os << ")";
881 }
882
doPrintValue32vkt::shaderexecutor::ContainerTraits883 static void doPrintValue32 (const FloatFormat& fmt, const T& value, ostream& os)
884 {
885 os << dataTypeNameOf<T>() << "(";
886
887 for (int ndx = 0; ndx < T::SIZE; ++ndx)
888 {
889 if (ndx > 0)
890 os << ", ";
891
892 printValue32<Element>(fmt, value[ndx], os);
893 }
894
895 os << ")";
896 }
897
doPrintValue64vkt::shaderexecutor::ContainerTraits898 static void doPrintValue64 (const FloatFormat& fmt, const T& value, ostream& os)
899 {
900 os << dataTypeNameOf<T>() << "(";
901
902 for (int ndx = 0; ndx < T::SIZE; ++ndx)
903 {
904 if (ndx > 0)
905 os << ", ";
906
907 printValue64<Element>(fmt, value[ndx], os);
908 }
909
910 os << ")";
911 }
912
doConvertvkt::shaderexecutor::ContainerTraits913 static IVal doConvert (const FloatFormat& fmt, const IVal& value)
914 {
915 IVal ret;
916
917 for (int ndx = 0; ndx < T::SIZE; ++ndx)
918 ret[ndx] = convert<Element>(fmt, value[ndx]);
919
920 return ret;
921 }
922
doRoundvkt::shaderexecutor::ContainerTraits923 static IVal doRound (const FloatFormat& fmt, T value)
924 {
925 IVal ret;
926
927 for (int ndx = 0; ndx < T::SIZE; ++ndx)
928 ret[ndx] = round(fmt, value[ndx]);
929
930 return ret;
931 }
932 };
933
934 template <typename T, int Size>
935 struct Traits<Vector<T, Size> > :
936 ContainerTraits<Vector<T, Size>, Vector<typename Traits<T>::IVal, Size> >
937 {
938 };
939
940 template <typename T, int Rows, int Cols>
941 struct Traits<Matrix<T, Rows, Cols> > :
942 ContainerTraits<Matrix<T, Rows, Cols>, Matrix<typename Traits<T>::IVal, Rows, Cols> >
943 {
944 };
945
946 //! Void traits. These are just dummies, but technically valid: a Void is a
947 //! unit type with a single possible value.
948 template<>
949 struct Traits<Void>
950 {
951 typedef Void IVal;
952
doMakeIValvkt::shaderexecutor::Traits953 static Void doMakeIVal (const Void& value) { return value; }
doUnionvkt::shaderexecutor::Traits954 static Void doUnion (const Void&, const Void&) { return Void(); }
doContainsvkt::shaderexecutor::Traits955 static bool doContains (const Void&, Void) { return true; }
956 template <typename U>
doContainsvkt::shaderexecutor::Traits957 static bool doContains (const Void&, const Void& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor) { DE_UNREF(value); DE_UNREF(is16Bit); DE_UNREF(modularDivisor); return true; }
doRoundvkt::shaderexecutor::Traits958 static Void doRound (const FloatFormat&, const Void& value) { return value; }
doConvertvkt::shaderexecutor::Traits959 static Void doConvert (const FloatFormat&, const Void& value) { return value; }
960
doPrintValue16vkt::shaderexecutor::Traits961 static void doPrintValue16 (const FloatFormat&, const Void&, ostream& os)
962 {
963 os << "()";
964 }
965
doPrintValue32vkt::shaderexecutor::Traits966 static void doPrintValue32 (const FloatFormat&, const Void&, ostream& os)
967 {
968 os << "()";
969 }
970
doPrintValue64vkt::shaderexecutor::Traits971 static void doPrintValue64 (const FloatFormat&, const Void&, ostream& os)
972 {
973 os << "()";
974 }
975
doPrintIValvkt::shaderexecutor::Traits976 static void doPrintIVal (const FloatFormat&, const Void&, ostream& os)
977 {
978 os << "()";
979 }
980 };
981
982 //! This is needed for container-generic operations.
983 //! We want a scalar type T to be its own "one-element vector".
984 template <typename T, int Size> struct ContainerOf { typedef Vector<T, Size> Container; };
985
986 template <typename T> struct ContainerOf<T, 1> { typedef T Container; };
987 template <int Size> struct ContainerOf<Void, Size> { typedef Void Container; };
988
// Kludge that exists only so the ExprP::operator[] syntactic sugar works.
// The general case forwards to the nested T::Element; the scalar types below
// have no such member and map to void instead.
template <typename T>
struct ElementOf
{
	typedef typename T::Element Element;
};

template <>
struct ElementOf<float>
{
	typedef void Element;
};

template <>
struct ElementOf<double>
{
	typedef void Element;
};

template <>
struct ElementOf<bool>
{
	typedef void Element;
};

template <>
struct ElementOf<int>
{
	typedef void Element;
};
995
996 template <typename T>
comparisonMessageInterval(const typename Traits<T>::IVal & val)997 string comparisonMessageInterval(const typename Traits<T>::IVal& val)
998 {
999 DE_UNREF(val);
1000 return "";
1001 }
1002
1003 template <>
comparisonMessageInterval(const Traits<int>::IVal & val)1004 string comparisonMessageInterval<int>(const Traits<int>::IVal& val)
1005 {
1006 return comparisonMessage(static_cast<int>(val.lo()));
1007 }
1008
1009 template <>
comparisonMessageInterval(const Traits<float>::IVal & val)1010 string comparisonMessageInterval<float>(const Traits<float>::IVal& val)
1011 {
1012 return comparisonMessage(static_cast<int>(val.lo()));
1013 }
1014
1015 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,2> & val)1016 string comparisonMessageInterval<tcu::Vector<int, 2> >(const tcu::Vector<tcu::Interval, 2> & val)
1017 {
1018 tcu::IVec2 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()));
1019 return comparisonMessage(result);
1020 }
1021
1022 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,3> & val)1023 string comparisonMessageInterval<tcu::Vector<int, 3> >(const tcu::Vector<tcu::Interval, 3> & val)
1024 {
1025 tcu::IVec3 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()));
1026 return comparisonMessage(result);
1027 }
1028
1029 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,4> & val)1030 string comparisonMessageInterval<tcu::Vector<int, 4> >(const tcu::Vector<tcu::Interval, 4> & val)
1031 {
1032 tcu::IVec4 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()), static_cast<int>(val[3].lo()));
1033 return comparisonMessage(result);
1034 }
1035
1036 /*--------------------------------------------------------------------*//*!
1037 *
1038 * \name Abstract syntax for expressions and statements.
1039 *
1040 * We represent GLSL programs as syntax objects: an Expr<T> represents an
1041 * expression whose GLSL type corresponds to the C++ type T, and a Statement
1042 * represents a statement.
1043 *
1044 * To ease memory management, we use shared pointers to refer to expressions
1045 * and statements. ExprP<T> is a shared pointer to an Expr<T>, and StatementP
1046 * is a shared pointer to a Statement.
1047 *
1048 * \{
1049 *
1050 *//*--------------------------------------------------------------------*/
1051
1052 class ExprBase;
1053 class ExpandContext;
1054 class Statement;
1055 class StatementP;
1056 class FuncBase;
1057 template <typename T> class ExprP;
1058 template <typename T> class Variable;
1059 template <typename T> class VariableP;
1060 template <typename T> class DefaultSampling;
1061
1062 typedef set<const FuncBase*> FuncSet;
1063
1064 template <typename T>
1065 VariableP<T> variable (const string& name);
1066 StatementP compoundStatement (const vector<StatementP>& statements);
1067
1068 /*--------------------------------------------------------------------*//*!
1069 * \brief A variable environment.
1070 *
1071 * An Environment object maintains the mapping between variables of the
1072 * abstract syntax tree and their values.
1073 *
1074 * \todo [2014-03-28 lauri] At least run-time type safety.
1075 *
1076 *//*--------------------------------------------------------------------*/
1077 class Environment
1078 {
1079 public:
1080 template<typename T>
bind(const Variable<T> & variable,const typename Traits<T>::IVal & value)1081 void bind (const Variable<T>& variable,
1082 const typename Traits<T>::IVal& value)
1083 {
1084 deUint8* const data = new deUint8[sizeof(value)];
1085
1086 deMemcpy(data, &value, sizeof(value));
1087 de::insert(m_map, variable.getName(), SharedPtr<deUint8>(data, de::ArrayDeleter<deUint8>()));
1088 }
1089
1090 template<typename T>
lookup(const Variable<T> & variable) const1091 typename Traits<T>::IVal& lookup (const Variable<T>& variable) const
1092 {
1093 deUint8* const data = de::lookup(m_map, variable.getName()).get();
1094
1095 return *reinterpret_cast<typename Traits<T>::IVal*>(data);
1096 }
1097
1098 private:
1099 map<string, SharedPtr<deUint8> > m_map;
1100 };
1101
1102 /*--------------------------------------------------------------------*//*!
1103 * \brief Evaluation context.
1104 *
1105 * The evaluation context contains everything that separates one execution of
1106 * an expression from the next. Currently this means the desired floating
1107 * point precision and the current variable environment.
1108 *
1109 *//*--------------------------------------------------------------------*/
1110 struct EvalContext
1111 {
EvalContextvkt::shaderexecutor::EvalContext1112 EvalContext (const FloatFormat& format_,
1113 Precision floatPrecision_,
1114 Environment& env_,
1115 int callDepth_)
1116 : format (format_)
1117 , floatPrecision (floatPrecision_)
1118 , env (env_)
1119 , callDepth (callDepth_) {}
1120
1121 FloatFormat format;
1122 Precision floatPrecision;
1123 Environment& env;
1124 int callDepth;
1125 };
1126
1127 /*--------------------------------------------------------------------*//*!
1128 * \brief Simple incremental counter.
1129 *
1130 * This is used to make sure that different ExpandContexts will not produce
1131 * overlapping temporary names.
1132 *
1133 *//*--------------------------------------------------------------------*/
class Counter
{
public:
	//! Start counting from `count` (zero by default).
	Counter (int count = 0)
		: m_count(count)
	{
	}

	//! Return the current count, then advance it by one.
	int operator() (void)
	{
		const int issued = m_count;

		m_count += 1;
		return issued;
	}

private:
	int m_count;
};
1143
1144 class ExpandContext
1145 {
1146 public:
ExpandContext(Counter & symCounter)1147 ExpandContext (Counter& symCounter) : m_symCounter(symCounter) {}
ExpandContext(const ExpandContext & parent)1148 ExpandContext (const ExpandContext& parent)
1149 : m_symCounter(parent.m_symCounter) {}
1150
1151 template<typename T>
genSym(const string & baseName)1152 VariableP<T> genSym (const string& baseName)
1153 {
1154 return variable<T>(baseName + de::toString(m_symCounter()));
1155 }
1156
addStatement(const StatementP & stmt)1157 void addStatement (const StatementP& stmt)
1158 {
1159 m_statements.push_back(stmt);
1160 }
1161
getStatements(void) const1162 vector<StatementP> getStatements (void) const
1163 {
1164 return m_statements;
1165 }
1166 private:
1167 Counter& m_symCounter;
1168 vector<StatementP> m_statements;
1169 };
1170
1171 /*--------------------------------------------------------------------*//*!
1172 * \brief A statement or declaration.
1173 *
1174 * Statements have no values. Instead, they are executed for their side
1175 * effects only: the execute() method should modify at least one variable in
1176 * the environment.
1177 *
1178 * As a bit of a kludge, a Statement object can also represent a declaration:
1179 * when it is evaluated, it can add a variable binding to the environment
1180 * instead of modifying a current one.
1181 *
1182 *//*--------------------------------------------------------------------*/
1183 class Statement
1184 {
1185 public:
~Statement(void)1186 virtual ~Statement (void) { }
1187 //! Execute the statement, modifying the environment of `ctx`
execute(EvalContext & ctx) const1188 void execute (EvalContext& ctx) const { this->doExecute(ctx); }
print(ostream & os) const1189 void print (ostream& os) const { this->doPrint(os); }
1190 //! Add the functions used in this statement to `dst`.
getUsedFuncs(FuncSet & dst) const1191 void getUsedFuncs (FuncSet& dst) const { this->doGetUsedFuncs(dst); }
failed(EvalContext & ctx) const1192 void failed (EvalContext& ctx) const { this->doFail(ctx); }
1193
1194 protected:
1195 virtual void doPrint (ostream& os) const = 0;
1196 virtual void doExecute (EvalContext& ctx) const = 0;
1197 virtual void doGetUsedFuncs (FuncSet& dst) const = 0;
doFail(EvalContext & ctx) const1198 virtual void doFail (EvalContext& ctx) const { DE_UNREF(ctx); }
1199 };
1200
operator <<(ostream & os,const Statement & stmt)1201 ostream& operator<<(ostream& os, const Statement& stmt)
1202 {
1203 stmt.print(os);
1204 return os;
1205 }
1206
1207 /*--------------------------------------------------------------------*//*!
1208 * \brief Smart pointer for statements (and declarations)
1209 *
1210 *//*--------------------------------------------------------------------*/
1211 class StatementP : public SharedPtr<const Statement>
1212 {
1213 public:
1214 typedef SharedPtr<const Statement> Super;
1215
StatementP(void)1216 StatementP (void) {}
StatementP(const Statement * ptr)1217 explicit StatementP (const Statement* ptr) : Super(ptr) {}
StatementP(const Super & ptr)1218 StatementP (const Super& ptr) : Super(ptr) {}
1219 };
1220
1221 /*--------------------------------------------------------------------*//*!
1222 * \brief
1223 *
1224 * A statement that modifies a variable or a declaration that binds a variable.
1225 *
1226 *//*--------------------------------------------------------------------*/
1227 template <typename T>
1228 class VariableStatement : public Statement
1229 {
1230 public:
VariableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1231 VariableStatement (const VariableP<T>& variable, const ExprP<T>& value,
1232 bool isDeclaration)
1233 : m_variable (variable)
1234 , m_value (value)
1235 , m_isDeclaration (isDeclaration) {}
1236
1237 protected:
doPrint(ostream & os) const1238 void doPrint (ostream& os) const
1239 {
1240 if (m_isDeclaration)
1241 os << glu::declare(getVarTypeOf<T>(), m_variable->getName());
1242 else
1243 os << m_variable->getName();
1244
1245 os << " = ";
1246 os<< *m_value << ";\n";
1247 }
1248
doExecute(EvalContext & ctx) const1249 void doExecute (EvalContext& ctx) const
1250 {
1251 if (m_isDeclaration)
1252 ctx.env.bind(*m_variable, m_value->evaluate(ctx));
1253 else
1254 ctx.env.lookup(*m_variable) = m_value->evaluate(ctx);
1255 }
1256
doGetUsedFuncs(FuncSet & dst) const1257 void doGetUsedFuncs (FuncSet& dst) const
1258 {
1259 m_value->getUsedFuncs(dst);
1260 }
1261
doFail(EvalContext & ctx) const1262 virtual void doFail (EvalContext& ctx) const
1263 {
1264 if (m_isDeclaration)
1265 ctx.env.bind(*m_variable, m_value->fails(ctx));
1266 else
1267 ctx.env.lookup(*m_variable) = m_value->fails(ctx);
1268 }
1269
1270 VariableP<T> m_variable;
1271 ExprP<T> m_value;
1272 bool m_isDeclaration;
1273 };
1274
1275 template <typename T>
variableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1276 StatementP variableStatement (const VariableP<T>& variable,
1277 const ExprP<T>& value,
1278 bool isDeclaration)
1279 {
1280 return StatementP(new VariableStatement<T>(variable, value, isDeclaration));
1281 }
1282
1283 template <typename T>
variableDeclaration(const VariableP<T> & variable,const ExprP<T> & definiens)1284 StatementP variableDeclaration (const VariableP<T>& variable, const ExprP<T>& definiens)
1285 {
1286 return variableStatement(variable, definiens, true);
1287 }
1288
1289 template <typename T>
variableAssignment(const VariableP<T> & variable,const ExprP<T> & value)1290 StatementP variableAssignment (const VariableP<T>& variable, const ExprP<T>& value)
1291 {
1292 return variableStatement(variable, value, false);
1293 }
1294
1295 /*--------------------------------------------------------------------*//*!
1296 * \brief A compound statement, i.e. a block.
1297 *
1298 * A compound statement is executed by executing its constituent statements in
1299 * sequence.
1300 *
1301 *//*--------------------------------------------------------------------*/
1302 class CompoundStatement : public Statement
1303 {
1304 public:
CompoundStatement(const vector<StatementP> & statements)1305 CompoundStatement (const vector<StatementP>& statements)
1306 : m_statements (statements) {}
1307
1308 protected:
doPrint(ostream & os) const1309 void doPrint (ostream& os) const
1310 {
1311 os << "{\n";
1312
1313 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1314 os << *m_statements[ndx];
1315
1316 os << "}\n";
1317 }
1318
doExecute(EvalContext & ctx) const1319 void doExecute (EvalContext& ctx) const
1320 {
1321 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1322 m_statements[ndx]->execute(ctx);
1323 }
1324
doGetUsedFuncs(FuncSet & dst) const1325 void doGetUsedFuncs (FuncSet& dst) const
1326 {
1327 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1328 m_statements[ndx]->getUsedFuncs(dst);
1329 }
1330
1331 vector<StatementP> m_statements;
1332 };
1333
compoundStatement(const vector<StatementP> & statements)1334 StatementP compoundStatement(const vector<StatementP>& statements)
1335 {
1336 return StatementP(new CompoundStatement(statements));
1337 }
1338
1339 //! Common base class for all expressions regardless of their type.
1340 class ExprBase
1341 {
1342 public:
~ExprBase(void)1343 virtual ~ExprBase (void) {}
printExpr(ostream & os) const1344 void printExpr (ostream& os) const { this->doPrintExpr(os); }
1345
1346 //! Output the functions that this expression refers to
getUsedFuncs(FuncSet & dst) const1347 void getUsedFuncs (FuncSet& dst) const
1348 {
1349 this->doGetUsedFuncs(dst);
1350 }
1351
1352 protected:
doPrintExpr(ostream &) const1353 virtual void doPrintExpr (ostream&) const {}
doGetUsedFuncs(FuncSet &) const1354 virtual void doGetUsedFuncs (FuncSet&) const {}
1355 };
1356
1357 //! Type-specific operations for an expression representing type T.
1358 template <typename T>
1359 class Expr : public ExprBase
1360 {
1361 public:
1362 typedef T Val;
1363 typedef typename Traits<T>::IVal IVal;
1364
1365 IVal evaluate (const EvalContext& ctx) const;
fails(const EvalContext & ctx) const1366 IVal fails (const EvalContext& ctx) const { return this->doFails(ctx); }
1367
1368 protected:
1369 virtual IVal doEvaluate (const EvalContext& ctx) const = 0;
doFails(const EvalContext & ctx) const1370 virtual IVal doFails (const EvalContext& ctx) const {return doEvaluate(ctx);}
1371 };
1372
1373 //! Evaluate an expression with the given context, optionally tracing the calls to stderr.
1374 template <typename T>
evaluate(const EvalContext & ctx) const1375 typename Traits<T>::IVal Expr<T>::evaluate (const EvalContext& ctx) const
1376 {
1377 #ifdef GLS_ENABLE_TRACE
1378 static const FloatFormat highpFmt (-126, 127, 23, true,
1379 tcu::MAYBE,
1380 tcu::YES,
1381 tcu::MAYBE);
1382 EvalContext newCtx (ctx.format, ctx.floatPrecision,
1383 ctx.env, ctx.callDepth + 1);
1384 const IVal ret = this->doEvaluate(newCtx);
1385
1386 if (isTypeValid<T>())
1387 {
1388 std::cerr << string(ctx.callDepth, ' ');
1389 this->printExpr(std::cerr);
1390 std::cerr << " -> " << intervalToString<T>(highpFmt, ret) << std::endl;
1391 }
1392 return ret;
1393 #else
1394 return this->doEvaluate(ctx);
1395 #endif
1396 }
1397
1398 template <typename T>
1399 class ExprPBase : public SharedPtr<const Expr<T> >
1400 {
1401 public:
1402 };
1403
operator <<(ostream & os,const ExprBase & expr)1404 ostream& operator<< (ostream& os, const ExprBase& expr)
1405 {
1406 expr.printExpr(os);
1407 return os;
1408 }
1409
1410 /*--------------------------------------------------------------------*//*!
1411 * \brief Shared pointer to an expression of a container type.
1412 *
1413 * Container types (i.e. vectors and matrices) support the subscription
1414 * operator. This class provides a bit of syntactic sugar to allow us to use
1415 * the C++ subscription operator to create a subscription expression.
1416 *//*--------------------------------------------------------------------*/
1417 template <typename T>
1418 class ContainerExprPBase : public ExprPBase<T>
1419 {
1420 public:
1421 ExprP<typename T::Element> operator[] (int i) const;
1422 };
1423
1424 template <typename T>
1425 class ExprP : public ExprPBase<T> {};
1426
1427 // We treat Voids as containers since the unused parameters in generalized
1428 // vector functions are represented as Voids.
1429 template <>
1430 class ExprP<Void> : public ContainerExprPBase<Void> {};
1431
1432 template <typename T, int Size>
1433 class ExprP<Vector<T, Size> > : public ContainerExprPBase<Vector<T, Size> > {};
1434
1435 template <typename T, int Rows, int Cols>
1436 class ExprP<Matrix<T, Rows, Cols> > : public ContainerExprPBase<Matrix<T, Rows, Cols> > {};
1437
exprP(void)1438 template <typename T> ExprP<T> exprP (void)
1439 {
1440 return ExprP<T>();
1441 }
1442
1443 template <typename T>
exprP(const SharedPtr<const Expr<T>> & ptr)1444 ExprP<T> exprP (const SharedPtr<const Expr<T> >& ptr)
1445 {
1446 ExprP<T> ret;
1447 static_cast<SharedPtr<const Expr<T> >&>(ret) = ptr;
1448 return ret;
1449 }
1450
1451 template <typename T>
exprP(const Expr<T> * ptr)1452 ExprP<T> exprP (const Expr<T>* ptr)
1453 {
1454 return exprP(SharedPtr<const Expr<T> >(ptr));
1455 }
1456
1457 /*--------------------------------------------------------------------*//*!
1458 * \brief A shared pointer to a variable expression.
1459 *
1460 * This is just a narrowing of ExprP for the operations that require a variable
1461 * instead of an arbitrary expression.
1462 *
1463 *//*--------------------------------------------------------------------*/
1464 template <typename T>
1465 class VariableP : public SharedPtr<const Variable<T> >
1466 {
1467 public:
1468 typedef SharedPtr<const Variable<T> > Super;
VariableP(const Variable<T> * ptr)1469 explicit VariableP (const Variable<T>* ptr) : Super(ptr) {}
VariableP(void)1470 VariableP (void) {}
VariableP(const Super & ptr)1471 VariableP (const Super& ptr) : Super(ptr) {}
1472
operator ExprP<T>(void) const1473 operator ExprP<T> (void) const { return exprP(SharedPtr<const Expr<T> >(*this)); }
1474 };
1475
1476 /*--------------------------------------------------------------------*//*!
1477 * \name Syntactic sugar operators for expressions.
1478 *
1479 * @{
1480 *
1481 * These operators allow the use of C++ syntax to construct GLSL expressions
1482 * containing operators: e.g. "a+b" creates an addition expression with
1483 * operands a and b, and so on.
1484 *
1485 *//*--------------------------------------------------------------------*/
1486 ExprP<float> operator+ (const ExprP<float>& arg0,
1487 const ExprP<float>& arg1);
1488 ExprP<deFloat16> operator+ (const ExprP<deFloat16>& arg0,
1489 const ExprP<deFloat16>& arg1);
1490 ExprP<double> operator+ (const ExprP<double>& arg0,
1491 const ExprP<double>& arg1);
1492 template <typename T>
1493 ExprP<T> operator- (const ExprP<T>& arg0);
1494 template <typename T>
1495 ExprP<T> operator- (const ExprP<T>& arg0,
1496 const ExprP<T>& arg1);
1497 template<int Left, int Mid, int Right, typename T>
1498 ExprP<Matrix<T, Left, Right> > operator* (const ExprP<Matrix<T, Left, Mid> >& left,
1499 const ExprP<Matrix<T, Mid, Right> >& right);
1500 ExprP<float> operator* (const ExprP<float>& arg0,
1501 const ExprP<float>& arg1);
1502 ExprP<deFloat16> operator* (const ExprP<deFloat16>& arg0,
1503 const ExprP<deFloat16>& arg1);
1504 ExprP<double> operator* (const ExprP<double>& arg0,
1505 const ExprP<double>& arg1);
1506 template <typename T>
1507 ExprP<T> operator/ (const ExprP<T>& arg0,
1508 const ExprP<T>& arg1);
1509 template<typename T, int Size>
1510 ExprP<Vector<T, Size> > operator- (const ExprP<Vector<T, Size> >& arg0);
1511 template<typename T, int Size>
1512 ExprP<Vector<T, Size> > operator- (const ExprP<Vector<T, Size> >& arg0,
1513 const ExprP<Vector<T, Size> >& arg1);
1514 template<int Size, typename T>
1515 ExprP<Vector<T, Size> > operator* (const ExprP<Vector<T, Size> >& arg0,
1516 const ExprP<T>& arg1);
1517 template<typename T, int Size>
1518 ExprP<Vector<T, Size> > operator* (const ExprP<Vector<T, Size> >& arg0,
1519 const ExprP<Vector<T, Size> >& arg1);
1520 template<int Rows, int Cols, typename T>
1521 ExprP<Vector<T, Rows> > operator* (const ExprP<Vector<T, Cols> >& left,
1522 const ExprP<Matrix<T, Rows, Cols> >& right);
1523 template<int Rows, int Cols, typename T>
1524 ExprP<Vector<T, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
1525 const ExprP<Vector<T, Rows> >& right);
1526 template<int Rows, int Cols, typename T>
1527 ExprP<Matrix<T, Rows, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
1528 const ExprP<T>& right);
1529 template<int Rows, int Cols>
1530 ExprP<Matrix<float, Rows, Cols> > operator+ (const ExprP<Matrix<float, Rows, Cols> >& left,
1531 const ExprP<Matrix<float, Rows, Cols> >& right);
1532 template<int Rows, int Cols>
1533 ExprP<Matrix<deFloat16, Rows, Cols> > operator+ (const ExprP<Matrix<deFloat16, Rows, Cols> >& left,
1534 const ExprP<Matrix<deFloat16, Rows, Cols> >& right);
1535 template<int Rows, int Cols>
1536 ExprP<Matrix<double, Rows, Cols> > operator+ (const ExprP<Matrix<double, Rows, Cols> >& left,
1537 const ExprP<Matrix<double, Rows, Cols> >& right);
1538 template<typename T, int Rows, int Cols>
1539 ExprP<Matrix<T, Rows, Cols> > operator- (const ExprP<Matrix<T, Rows, Cols> >& mat);
1540
1541 //! @}
1542
1543 /*--------------------------------------------------------------------*//*!
1544 * \brief Variable expression.
1545 *
1546 * A variable is evaluated by looking up its range of possible values from an
1547 * environment.
1548 *//*--------------------------------------------------------------------*/
1549 template <typename T>
1550 class Variable : public Expr<T>
1551 {
1552 public:
1553 typedef typename Expr<T>::IVal IVal;
1554
Variable(const string & name)1555 Variable (const string& name) : m_name (name) {}
getName(void) const1556 string getName (void) const { return m_name; }
1557
1558 protected:
doPrintExpr(ostream & os) const1559 void doPrintExpr (ostream& os) const { os << m_name; }
doEvaluate(const EvalContext & ctx) const1560 IVal doEvaluate (const EvalContext& ctx) const
1561 {
1562 return ctx.env.lookup<T>(*this);
1563 }
1564
1565 private:
1566 string m_name;
1567 };
1568
1569 template <typename T>
variable(const string & name)1570 VariableP<T> variable (const string& name)
1571 {
1572 return VariableP<T>(new Variable<T>(name));
1573 }
1574
1575 template <typename T>
bindExpression(const string & name,ExpandContext & ctx,const ExprP<T> & expr)1576 VariableP<T> bindExpression (const string& name, ExpandContext& ctx, const ExprP<T>& expr)
1577 {
1578 VariableP<T> var = ctx.genSym<T>(name);
1579 ctx.addStatement(variableDeclaration(var, expr));
1580 return var;
1581 }
1582
1583 /*--------------------------------------------------------------------*//*!
1584 * \brief Constant expression.
1585 *
1586 * A constant is evaluated by rounding it to a set of possible values allowed
1587 * by the current floating point precision.
1588 *//*--------------------------------------------------------------------*/
1589 template <typename T>
1590 class Constant : public Expr<T>
1591 {
1592 public:
1593 typedef typename Expr<T>::IVal IVal;
1594
Constant(const T & value)1595 Constant (const T& value) : m_value(value) {}
1596
1597 protected:
doPrintExpr(ostream & os) const1598 void doPrintExpr (ostream& os) const { os << m_value; }
doEvaluate(const EvalContext &) const1599 IVal doEvaluate (const EvalContext&) const { return makeIVal(m_value); }
1600
1601 private:
1602 T m_value;
1603 };
1604
1605 template <typename T>
constant(const T & value)1606 ExprP<T> constant (const T& value)
1607 {
1608 return exprP(new Constant<T>(value));
1609 }
1610
1611 //! Return a reference to a singleton void constant.
voidP(void)1612 const ExprP<Void>& voidP (void)
1613 {
1614 static const ExprP<Void> singleton = constant(Void());
1615
1616 return singleton;
1617 }
1618
1619 /*--------------------------------------------------------------------*//*!
1620 * \brief Four-element tuple.
1621 *
1622 * This is used for various things where we need one thing for each possible
1623 * function parameter. Currently the maximum supported number of parameters is
1624 * four.
1625 *//*--------------------------------------------------------------------*/
1626 template <typename T0 = Void, typename T1 = Void, typename T2 = Void, typename T3 = Void>
1627 struct Tuple4
1628 {
Tuple4vkt::shaderexecutor::Tuple41629 explicit Tuple4 (const T0 e0 = T0(),
1630 const T1 e1 = T1(),
1631 const T2 e2 = T2(),
1632 const T3 e3 = T3())
1633 : a (e0)
1634 , b (e1)
1635 , c (e2)
1636 , d (e3)
1637 {
1638 }
1639
1640 T0 a;
1641 T1 b;
1642 T2 c;
1643 T3 d;
1644 };
1645
1646 /*--------------------------------------------------------------------*//*!
1647 * \brief Function signature.
1648 *
1649 * This is a purely compile-time structure used to bundle all types in a
1650 * function signature together. This makes passing the signature around in
1651 * templates easier, since we only need to take and pass a single Sig instead
1652 * of a bunch of parameter types and a return type.
1653 *
1654 *//*--------------------------------------------------------------------*/
1655 template <typename R,
1656 typename P0 = Void, typename P1 = Void,
1657 typename P2 = Void, typename P3 = Void>
1658 struct Signature
1659 {
1660 typedef R Ret;
1661 typedef P0 Arg0;
1662 typedef P1 Arg1;
1663 typedef P2 Arg2;
1664 typedef P3 Arg3;
1665 typedef typename Traits<Ret>::IVal IRet;
1666 typedef typename Traits<Arg0>::IVal IArg0;
1667 typedef typename Traits<Arg1>::IVal IArg1;
1668 typedef typename Traits<Arg2>::IVal IArg2;
1669 typedef typename Traits<Arg3>::IVal IArg3;
1670
1671 typedef Tuple4< const Arg0&, const Arg1&, const Arg2&, const Arg3&> Args;
1672 typedef Tuple4< const IArg0&, const IArg1&, const IArg2&, const IArg3&> IArgs;
1673 typedef Tuple4< ExprP<Arg0>, ExprP<Arg1>, ExprP<Arg2>, ExprP<Arg3> > ArgExprs;
1674 };
1675
1676 typedef vector<const ExprBase*> BaseArgExprs;
1677
1678 /*--------------------------------------------------------------------*//*!
1679 * \brief Type-independent operations for function objects.
1680 *
1681 *//*--------------------------------------------------------------------*/
1682 class FuncBase
1683 {
1684 public:
~FuncBase(void)1685 virtual ~FuncBase (void) {}
1686 virtual string getName (void) const = 0;
1687 //! Name of extension that this function requires, or empty.
getRequiredExtension(void) const1688 virtual string getRequiredExtension (void) const { return ""; }
getInputRange(const bool is16bit) const1689 virtual Interval getInputRange (const bool is16bit) const {DE_UNREF(is16bit); return Interval(true, -TCU_INFINITY, TCU_INFINITY); }
1690 virtual void print (ostream&,
1691 const BaseArgExprs&) const = 0;
1692 //! Index of output parameter, or -1 if none of the parameters is output.
getOutParamIndex(void) const1693 virtual int getOutParamIndex (void) const { return -1; }
1694
getSpirvCase(void) const1695 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_NONE; }
1696
printDefinition(ostream & os) const1697 void printDefinition (ostream& os) const
1698 {
1699 doPrintDefinition(os);
1700 }
1701
getUsedFuncs(FuncSet & dst) const1702 void getUsedFuncs (FuncSet& dst) const
1703 {
1704 this->doGetUsedFuncs(dst);
1705 }
1706
1707 protected:
1708 virtual void doPrintDefinition (ostream& os) const = 0;
1709 virtual void doGetUsedFuncs (FuncSet& dst) const = 0;
1710 };
1711
1712 typedef Tuple4<string, string, string, string> ParamNames;
1713
1714 /*--------------------------------------------------------------------*//*!
1715 * \brief Function objects.
1716 *
1717 * Each Func object represents a GLSL function. It can be applied to interval
1718 * arguments, and it returns an interval that is a conservative
1719 * approximation of the image of the GLSL function over the argument
1720 * intervals. That is, it is given a set of possible arguments and it returns
1721 * the set of possible values.
1722 *
1723 *//*--------------------------------------------------------------------*/
1724 template <typename Sig_>
1725 class Func : public FuncBase
1726 {
1727 public:
1728 typedef Sig_ Sig;
1729 typedef typename Sig::Ret Ret;
1730 typedef typename Sig::Arg0 Arg0;
1731 typedef typename Sig::Arg1 Arg1;
1732 typedef typename Sig::Arg2 Arg2;
1733 typedef typename Sig::Arg3 Arg3;
1734 typedef typename Sig::IRet IRet;
1735 typedef typename Sig::IArg0 IArg0;
1736 typedef typename Sig::IArg1 IArg1;
1737 typedef typename Sig::IArg2 IArg2;
1738 typedef typename Sig::IArg3 IArg3;
1739 typedef typename Sig::Args Args;
1740 typedef typename Sig::IArgs IArgs;
1741 typedef typename Sig::ArgExprs ArgExprs;
1742
print(ostream & os,const BaseArgExprs & args) const1743 void print (ostream& os,
1744 const BaseArgExprs& args) const
1745 {
1746 this->doPrint(os, args);
1747 }
1748
apply(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1749 IRet apply (const EvalContext& ctx,
1750 const IArg0& arg0 = IArg0(),
1751 const IArg1& arg1 = IArg1(),
1752 const IArg2& arg2 = IArg2(),
1753 const IArg3& arg3 = IArg3()) const
1754 {
1755 return this->applyArgs(ctx, IArgs(arg0, arg1, arg2, arg3));
1756 }
1757
fail(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1758 IRet fail (const EvalContext& ctx,
1759 const IArg0& arg0 = IArg0(),
1760 const IArg1& arg1 = IArg1(),
1761 const IArg2& arg2 = IArg2(),
1762 const IArg3& arg3 = IArg3()) const
1763 {
1764 return this->doFail(ctx, IArgs(arg0, arg1, arg2, arg3));
1765 }
applyArgs(const EvalContext & ctx,const IArgs & args) const1766 IRet applyArgs (const EvalContext& ctx,
1767 const IArgs& args) const
1768 {
1769 return this->doApply(ctx, args);
1770 }
1771 ExprP<Ret> operator() (const ExprP<Arg0>& arg0 = voidP(),
1772 const ExprP<Arg1>& arg1 = voidP(),
1773 const ExprP<Arg2>& arg2 = voidP(),
1774 const ExprP<Arg3>& arg3 = voidP()) const;
1775
getParamNames(void) const1776 const ParamNames& getParamNames (void) const
1777 {
1778 return this->doGetParamNames();
1779 }
1780
1781 protected:
1782 virtual IRet doApply (const EvalContext&,
1783 const IArgs&) const = 0;
doFail(const EvalContext & ctx,const IArgs & args) const1784 virtual IRet doFail (const EvalContext& ctx,
1785 const IArgs& args) const
1786 {
1787 return this->doApply(ctx, args);
1788 }
doPrint(ostream & os,const BaseArgExprs & args) const1789 virtual void doPrint (ostream& os, const BaseArgExprs& args) const
1790 {
1791 os << getName() << "(";
1792
1793 if (isTypeValid<Arg0>())
1794 os << *args[0];
1795
1796 if (isTypeValid<Arg1>())
1797 os << ", " << *args[1];
1798
1799 if (isTypeValid<Arg2>())
1800 os << ", " << *args[2];
1801
1802 if (isTypeValid<Arg3>())
1803 os << ", " << *args[3];
1804
1805 os << ")";
1806 }
1807
doGetParamNames(void) const1808 virtual const ParamNames& doGetParamNames (void) const
1809 {
1810 static ParamNames names ("a", "b", "c", "d");
1811 return names;
1812 }
1813 };
1814
1815 template <typename Sig>
1816 class Apply : public Expr<typename Sig::Ret>
1817 {
1818 public:
1819 typedef typename Sig::Ret Ret;
1820 typedef typename Sig::Arg0 Arg0;
1821 typedef typename Sig::Arg1 Arg1;
1822 typedef typename Sig::Arg2 Arg2;
1823 typedef typename Sig::Arg3 Arg3;
1824 typedef typename Expr<Ret>::Val Val;
1825 typedef typename Expr<Ret>::IVal IVal;
1826 typedef Func<Sig> ApplyFunc;
1827 typedef typename ApplyFunc::ArgExprs ArgExprs;
1828
Apply(const ApplyFunc & func,const ExprP<Arg0> & arg0=voidP (),const ExprP<Arg1> & arg1=voidP (),const ExprP<Arg2> & arg2=voidP (),const ExprP<Arg3> & arg3=voidP ())1829 Apply (const ApplyFunc& func,
1830 const ExprP<Arg0>& arg0 = voidP(),
1831 const ExprP<Arg1>& arg1 = voidP(),
1832 const ExprP<Arg2>& arg2 = voidP(),
1833 const ExprP<Arg3>& arg3 = voidP())
1834 : m_func (func),
1835 m_args (arg0, arg1, arg2, arg3) {}
1836
Apply(const ApplyFunc & func,const ArgExprs & args)1837 Apply (const ApplyFunc& func,
1838 const ArgExprs& args)
1839 : m_func (func),
1840 m_args (args) {}
1841 protected:
doPrintExpr(ostream & os) const1842 void doPrintExpr (ostream& os) const
1843 {
1844 BaseArgExprs args;
1845 args.push_back(m_args.a.get());
1846 args.push_back(m_args.b.get());
1847 args.push_back(m_args.c.get());
1848 args.push_back(m_args.d.get());
1849 m_func.print(os, args);
1850 }
1851
doEvaluate(const EvalContext & ctx) const1852 IVal doEvaluate (const EvalContext& ctx) const
1853 {
1854 return m_func.apply(ctx,
1855 m_args.a->evaluate(ctx), m_args.b->evaluate(ctx),
1856 m_args.c->evaluate(ctx), m_args.d->evaluate(ctx));
1857 }
1858
doGetUsedFuncs(FuncSet & dst) const1859 void doGetUsedFuncs (FuncSet& dst) const
1860 {
1861 m_func.getUsedFuncs(dst);
1862 m_args.a->getUsedFuncs(dst);
1863 m_args.b->getUsedFuncs(dst);
1864 m_args.c->getUsedFuncs(dst);
1865 m_args.d->getUsedFuncs(dst);
1866 }
1867
1868 const ApplyFunc& m_func;
1869 ArgExprs m_args;
1870 };
1871
1872 template<typename T>
1873 class Alternatives : public Func<Signature<T, T, T> >
1874 {
1875 public:
1876 typedef typename Alternatives::Sig Sig;
1877
1878 protected:
1879 typedef typename Alternatives::IRet IRet;
1880 typedef typename Alternatives::IArgs IArgs;
1881
getName(void) const1882 virtual string getName (void) const { return "alternatives"; }
doPrintDefinition(std::ostream &) const1883 virtual void doPrintDefinition (std::ostream&) const {}
doGetUsedFuncs(FuncSet &) const1884 void doGetUsedFuncs (FuncSet&) const {}
1885
doApply(const EvalContext &,const IArgs & args) const1886 virtual IRet doApply (const EvalContext&, const IArgs& args) const
1887 {
1888 return unionIVal<T>(args.a, args.b);
1889 }
1890
doPrint(ostream & os,const BaseArgExprs & args) const1891 virtual void doPrint (ostream& os, const BaseArgExprs& args) const
1892 {
1893 os << "{" << *args[0] << " | " << *args[1] << "}";
1894 }
1895 };
1896
1897 template <typename Sig>
createApply(const Func<Sig> & func,const typename Func<Sig>::ArgExprs & args)1898 ExprP<typename Sig::Ret> createApply (const Func<Sig>& func,
1899 const typename Func<Sig>::ArgExprs& args)
1900 {
1901 return exprP(new Apply<Sig>(func, args));
1902 }
1903
1904 template <typename Sig>
createApply(const Func<Sig> & func,const ExprP<typename Sig::Arg0> & arg0=voidP (),const ExprP<typename Sig::Arg1> & arg1=voidP (),const ExprP<typename Sig::Arg2> & arg2=voidP (),const ExprP<typename Sig::Arg3> & arg3=voidP ())1905 ExprP<typename Sig::Ret> createApply (
1906 const Func<Sig>& func,
1907 const ExprP<typename Sig::Arg0>& arg0 = voidP(),
1908 const ExprP<typename Sig::Arg1>& arg1 = voidP(),
1909 const ExprP<typename Sig::Arg2>& arg2 = voidP(),
1910 const ExprP<typename Sig::Arg3>& arg3 = voidP())
1911 {
1912 return exprP(new Apply<Sig>(func, arg0, arg1, arg2, arg3));
1913 }
1914
1915 template <typename Sig>
operator ()(const ExprP<typename Sig::Arg0> & arg0,const ExprP<typename Sig::Arg1> & arg1,const ExprP<typename Sig::Arg2> & arg2,const ExprP<typename Sig::Arg3> & arg3) const1916 ExprP<typename Sig::Ret> Func<Sig>::operator() (const ExprP<typename Sig::Arg0>& arg0,
1917 const ExprP<typename Sig::Arg1>& arg1,
1918 const ExprP<typename Sig::Arg2>& arg2,
1919 const ExprP<typename Sig::Arg3>& arg3) const
1920 {
1921 return createApply(*this, arg0, arg1, arg2, arg3);
1922 }
1923
1924 template <typename F>
app(const ExprP<typename F::Arg0> & arg0=voidP (),const ExprP<typename F::Arg1> & arg1=voidP (),const ExprP<typename F::Arg2> & arg2=voidP (),const ExprP<typename F::Arg3> & arg3=voidP ())1925 ExprP<typename F::Ret> app (const ExprP<typename F::Arg0>& arg0 = voidP(),
1926 const ExprP<typename F::Arg1>& arg1 = voidP(),
1927 const ExprP<typename F::Arg2>& arg2 = voidP(),
1928 const ExprP<typename F::Arg3>& arg3 = voidP())
1929 {
1930 return createApply(instance<F>(), arg0, arg1, arg2, arg3);
1931 }
1932
1933 template <typename F>
call(const EvalContext & ctx,const typename F::IArg0 & arg0=Void (),const typename F::IArg1 & arg1=Void (),const typename F::IArg2 & arg2=Void (),const typename F::IArg3 & arg3=Void ())1934 typename F::IRet call (const EvalContext& ctx,
1935 const typename F::IArg0& arg0 = Void(),
1936 const typename F::IArg1& arg1 = Void(),
1937 const typename F::IArg2& arg2 = Void(),
1938 const typename F::IArg3& arg3 = Void())
1939 {
1940 return instance<F>().apply(ctx, arg0, arg1, arg2, arg3);
1941 }
1942
1943 template <typename T>
alternatives(const ExprP<T> & arg0,const ExprP<T> & arg1)1944 ExprP<T> alternatives (const ExprP<T>& arg0,
1945 const ExprP<T>& arg1)
1946 {
1947 return createApply<typename Alternatives<T>::Sig>(instance<Alternatives<T> >(), arg0, arg1);
1948 }
1949
1950 template <typename Sig>
1951 class ApplyVar : public Apply<Sig>
1952 {
1953 public:
1954 typedef typename Sig::Ret Ret;
1955 typedef typename Sig::Arg0 Arg0;
1956 typedef typename Sig::Arg1 Arg1;
1957 typedef typename Sig::Arg2 Arg2;
1958 typedef typename Sig::Arg3 Arg3;
1959 typedef typename Expr<Ret>::Val Val;
1960 typedef typename Expr<Ret>::IVal IVal;
1961 typedef Func<Sig> ApplyFunc;
1962 typedef typename ApplyFunc::ArgExprs ArgExprs;
1963
ApplyVar(const ApplyFunc & func,const VariableP<Arg0> & arg0,const VariableP<Arg1> & arg1,const VariableP<Arg2> & arg2,const VariableP<Arg3> & arg3)1964 ApplyVar (const ApplyFunc& func,
1965 const VariableP<Arg0>& arg0,
1966 const VariableP<Arg1>& arg1,
1967 const VariableP<Arg2>& arg2,
1968 const VariableP<Arg3>& arg3)
1969 : Apply<Sig> (func, arg0, arg1, arg2, arg3) {}
1970 protected:
doEvaluate(const EvalContext & ctx) const1971 IVal doEvaluate (const EvalContext& ctx) const
1972 {
1973 const Variable<Arg0>& var0 = static_cast<const Variable<Arg0>&>(*this->m_args.a);
1974 const Variable<Arg1>& var1 = static_cast<const Variable<Arg1>&>(*this->m_args.b);
1975 const Variable<Arg2>& var2 = static_cast<const Variable<Arg2>&>(*this->m_args.c);
1976 const Variable<Arg3>& var3 = static_cast<const Variable<Arg3>&>(*this->m_args.d);
1977 return this->m_func.apply(ctx,
1978 ctx.env.lookup(var0), ctx.env.lookup(var1),
1979 ctx.env.lookup(var2), ctx.env.lookup(var3));
1980 }
1981
doFails(const EvalContext & ctx) const1982 IVal doFails (const EvalContext& ctx) const
1983 {
1984 const Variable<Arg0>& var0 = static_cast<const Variable<Arg0>&>(*this->m_args.a);
1985 const Variable<Arg1>& var1 = static_cast<const Variable<Arg1>&>(*this->m_args.b);
1986 const Variable<Arg2>& var2 = static_cast<const Variable<Arg2>&>(*this->m_args.c);
1987 const Variable<Arg3>& var3 = static_cast<const Variable<Arg3>&>(*this->m_args.d);
1988 return this->m_func.fail(ctx,
1989 ctx.env.lookup(var0), ctx.env.lookup(var1),
1990 ctx.env.lookup(var2), ctx.env.lookup(var3));
1991 }
1992 };
1993
1994 template <typename Sig>
applyVar(const Func<Sig> & func,const VariableP<typename Sig::Arg0> & arg0,const VariableP<typename Sig::Arg1> & arg1,const VariableP<typename Sig::Arg2> & arg2,const VariableP<typename Sig::Arg3> & arg3)1995 ExprP<typename Sig::Ret> applyVar (const Func<Sig>& func,
1996 const VariableP<typename Sig::Arg0>& arg0,
1997 const VariableP<typename Sig::Arg1>& arg1,
1998 const VariableP<typename Sig::Arg2>& arg2,
1999 const VariableP<typename Sig::Arg3>& arg3)
2000 {
2001 return exprP(new ApplyVar<Sig>(func, arg0, arg1, arg2, arg3));
2002 }
2003
2004 template <typename Sig_>
2005 class DerivedFunc : public Func<Sig_>
2006 {
2007 public:
2008 typedef typename DerivedFunc::ArgExprs ArgExprs;
2009 typedef typename DerivedFunc::IRet IRet;
2010 typedef typename DerivedFunc::IArgs IArgs;
2011 typedef typename DerivedFunc::Ret Ret;
2012 typedef typename DerivedFunc::Arg0 Arg0;
2013 typedef typename DerivedFunc::Arg1 Arg1;
2014 typedef typename DerivedFunc::Arg2 Arg2;
2015 typedef typename DerivedFunc::Arg3 Arg3;
2016 typedef typename DerivedFunc::IArg0 IArg0;
2017 typedef typename DerivedFunc::IArg1 IArg1;
2018 typedef typename DerivedFunc::IArg2 IArg2;
2019 typedef typename DerivedFunc::IArg3 IArg3;
2020
2021 protected:
doPrintDefinition(ostream & os) const2022 void doPrintDefinition (ostream& os) const
2023 {
2024 const ParamNames& paramNames = this->getParamNames();
2025
2026 initialize();
2027
2028 os << dataTypeNameOf<Ret>() << " " << this->getName()
2029 << "(";
2030 if (isTypeValid<Arg0>())
2031 os << dataTypeNameOf<Arg0>() << " " << paramNames.a;
2032 if (isTypeValid<Arg1>())
2033 os << ", " << dataTypeNameOf<Arg1>() << " " << paramNames.b;
2034 if (isTypeValid<Arg2>())
2035 os << ", " << dataTypeNameOf<Arg2>() << " " << paramNames.c;
2036 if (isTypeValid<Arg3>())
2037 os << ", " << dataTypeNameOf<Arg3>() << " " << paramNames.d;
2038 os << ")\n{\n";
2039
2040 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2041 os << *m_body[ndx];
2042 os << "return " << *m_ret << ";\n";
2043 os << "}\n";
2044 }
2045
doApply(const EvalContext & ctx,const IArgs & args) const2046 IRet doApply (const EvalContext& ctx,
2047 const IArgs& args) const
2048 {
2049 Environment funEnv;
2050 IArgs& mutArgs = const_cast<IArgs&>(args);
2051 IRet ret;
2052
2053 initialize();
2054
2055 funEnv.bind(*m_var0, args.a);
2056 funEnv.bind(*m_var1, args.b);
2057 funEnv.bind(*m_var2, args.c);
2058 funEnv.bind(*m_var3, args.d);
2059
2060 {
2061 EvalContext funCtx(ctx.format, ctx.floatPrecision, funEnv, ctx.callDepth);
2062
2063 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2064 m_body[ndx]->execute(funCtx);
2065
2066 ret = m_ret->evaluate(funCtx);
2067 }
2068
2069 // \todo [lauri] Store references instead of values in environment
2070 const_cast<IArg0&>(mutArgs.a) = funEnv.lookup(*m_var0);
2071 const_cast<IArg1&>(mutArgs.b) = funEnv.lookup(*m_var1);
2072 const_cast<IArg2&>(mutArgs.c) = funEnv.lookup(*m_var2);
2073 const_cast<IArg3&>(mutArgs.d) = funEnv.lookup(*m_var3);
2074
2075 return ret;
2076 }
2077
doGetUsedFuncs(FuncSet & dst) const2078 void doGetUsedFuncs (FuncSet& dst) const
2079 {
2080 initialize();
2081 if (dst.insert(this).second)
2082 {
2083 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2084 m_body[ndx]->getUsedFuncs(dst);
2085 m_ret->getUsedFuncs(dst);
2086 }
2087 }
2088
2089 virtual ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args_) const = 0;
2090
2091 // These are transparently initialized when first needed. They cannot be
2092 // initialized in the constructor because they depend on the doExpand
2093 // method of the subclass.
2094
2095 mutable VariableP<Arg0> m_var0;
2096 mutable VariableP<Arg1> m_var1;
2097 mutable VariableP<Arg2> m_var2;
2098 mutable VariableP<Arg3> m_var3;
2099 mutable vector<StatementP> m_body;
2100 mutable ExprP<Ret> m_ret;
2101
2102 private:
2103
initialize(void) const2104 void initialize (void) const
2105 {
2106 if (!m_ret)
2107 {
2108 const ParamNames& paramNames = this->getParamNames();
2109 Counter symCounter;
2110 ExpandContext ctx (symCounter);
2111 ArgExprs args;
2112
2113 args.a = m_var0 = variable<Arg0>(paramNames.a);
2114 args.b = m_var1 = variable<Arg1>(paramNames.b);
2115 args.c = m_var2 = variable<Arg2>(paramNames.c);
2116 args.d = m_var3 = variable<Arg3>(paramNames.d);
2117
2118 m_ret = this->doExpand(ctx, args);
2119 m_body = ctx.getStatements();
2120 }
2121 }
2122 };
2123
2124 template <typename Sig>
2125 class PrimitiveFunc : public Func<Sig>
2126 {
2127 public:
2128 typedef typename PrimitiveFunc::Ret Ret;
2129 typedef typename PrimitiveFunc::ArgExprs ArgExprs;
2130
2131 protected:
doPrintDefinition(ostream &) const2132 void doPrintDefinition (ostream&) const {}
doGetUsedFuncs(FuncSet &) const2133 void doGetUsedFuncs (FuncSet&) const {}
2134 };
2135
2136 template <typename T>
2137 class Cond : public PrimitiveFunc<Signature<T, bool, T, T> >
2138 {
2139 public:
2140 typedef typename Cond::IArgs IArgs;
2141 typedef typename Cond::IRet IRet;
2142
getName(void) const2143 string getName (void) const
2144 {
2145 return "_cond";
2146 }
2147
2148 protected:
2149
doPrint(ostream & os,const BaseArgExprs & args) const2150 void doPrint (ostream& os, const BaseArgExprs& args) const
2151 {
2152 os << "(" << *args[0] << " ? " << *args[1] << " : " << *args[2] << ")";
2153 }
2154
doApply(const EvalContext &,const IArgs & iargs) const2155 IRet doApply (const EvalContext&, const IArgs& iargs)const
2156 {
2157 IRet ret;
2158
2159 if (iargs.a.contains(true))
2160 ret = unionIVal<T>(ret, iargs.b);
2161
2162 if (iargs.a.contains(false))
2163 ret = unionIVal<T>(ret, iargs.c);
2164
2165 return ret;
2166 }
2167 };
2168
2169 template <typename T>
2170 class CompareOperator : public PrimitiveFunc<Signature<bool, T, T> >
2171 {
2172 public:
2173 typedef typename CompareOperator::IArgs IArgs;
2174 typedef typename CompareOperator::IArg0 IArg0;
2175 typedef typename CompareOperator::IArg1 IArg1;
2176 typedef typename CompareOperator::IRet IRet;
2177
2178 protected:
doPrint(ostream & os,const BaseArgExprs & args) const2179 void doPrint (ostream& os, const BaseArgExprs& args) const
2180 {
2181 os << "(" << *args[0] << getSymbol() << *args[1] << ")";
2182 }
2183
doApply(const EvalContext &,const IArgs & iargs) const2184 Interval doApply (const EvalContext&, const IArgs& iargs) const
2185 {
2186 const IArg0& arg0 = iargs.a;
2187 const IArg1& arg1 = iargs.b;
2188 IRet ret;
2189
2190 if (canSucceed(arg0, arg1))
2191 ret |= true;
2192 if (canFail(arg0, arg1))
2193 ret |= false;
2194
2195 return ret;
2196 }
2197
2198 virtual string getSymbol (void) const = 0;
2199 virtual bool canSucceed (const IArg0&, const IArg1&) const = 0;
2200 virtual bool canFail (const IArg0&, const IArg1&) const = 0;
2201 };
2202
2203 template <typename T>
2204 class LessThan : public CompareOperator<T>
2205 {
2206 public:
getName(void) const2207 string getName (void) const { return "lessThan"; }
2208
2209 protected:
getSymbol(void) const2210 string getSymbol (void) const { return "<"; }
2211
canSucceed(const Interval & a,const Interval & b) const2212 bool canSucceed (const Interval& a, const Interval& b) const
2213 {
2214 return (a.lo() < b.hi());
2215 }
2216
canFail(const Interval & a,const Interval & b) const2217 bool canFail (const Interval& a, const Interval& b) const
2218 {
2219 return !(a.hi() < b.lo());
2220 }
2221 };
2222
2223 template <typename T>
operator <(const ExprP<T> & a,const ExprP<T> & b)2224 ExprP<bool> operator< (const ExprP<T>& a, const ExprP<T>& b)
2225 {
2226 return app<LessThan<T> >(a, b);
2227 }
2228
2229 template <typename T>
cond(const ExprP<bool> & test,const ExprP<T> & consequent,const ExprP<T> & alternative)2230 ExprP<T> cond (const ExprP<bool>& test,
2231 const ExprP<T>& consequent,
2232 const ExprP<T>& alternative)
2233 {
2234 return app<Cond<T> >(test, consequent, alternative);
2235 }
2236
2237 /*--------------------------------------------------------------------*//*!
2238 *
2239 * @}
2240 *
2241 *//*--------------------------------------------------------------------*/
2242 //Proper parameters for template T
2243 // Signature<float, float> 32bit tests
2244 // Signature<float, deFloat16> 16bit tests
2245 // Signature<double, double> 64bit tests
2246 template< class T>
2247 class FloatFunc1 : public PrimitiveFunc<T>
2248 {
2249 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0>::IArgs & iargs) const2250 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0>::IArgs& iargs) const
2251 {
2252 return this->applyMonotone(ctx, iargs.a);
2253 }
2254
applyMonotone(const EvalContext & ctx,const Interval & iarg0) const2255 Interval applyMonotone (const EvalContext& ctx, const Interval& iarg0) const
2256 {
2257 Interval ret;
2258
2259 TCU_INTERVAL_APPLY_MONOTONE1(ret, arg0, iarg0, val,
2260 TCU_SET_INTERVAL(val, point,
2261 point = this->applyPoint(ctx, arg0)));
2262
2263 ret |= innerExtrema(ctx, iarg0);
2264 ret &= (this->getCodomain(ctx) | TCU_NAN);
2265
2266 return ctx.format.convert(ret);
2267 }
2268
innerExtrema(const EvalContext &,const Interval &) const2269 virtual Interval innerExtrema (const EvalContext&, const Interval&) const
2270 {
2271 return Interval(); // empty interval, i.e. no extrema
2272 }
2273
applyPoint(const EvalContext & ctx,double arg0) const2274 virtual Interval applyPoint (const EvalContext& ctx, double arg0) const
2275 {
2276 const double exact = this->applyExact(arg0);
2277 const double prec = this->precision(ctx, exact, arg0);
2278
2279 return exact + Interval(-prec, prec);
2280 }
2281
applyExact(double) const2282 virtual double applyExact (double) const
2283 {
2284 TCU_THROW(InternalError, "Cannot apply");
2285 }
2286
getCodomain(const EvalContext &) const2287 virtual Interval getCodomain (const EvalContext&) const
2288 {
2289 return Interval::unbounded(true);
2290 }
2291
2292 virtual double precision (const EvalContext& ctx, double, double) const = 0;
2293 };
2294
2295 /*Proper parameters for template T
2296 Signature<double, double> 64bit tests
2297 Signature<float, float> 32bit tests
2298 Signature<float, deFloat16> 16bit tests*/
2299 template <class T>
2300 class CFloatFunc1 : public FloatFunc1<T>
2301 {
2302 public:
CFloatFunc1(const string & name,tcu::DoubleFunc1 & func)2303 CFloatFunc1 (const string& name, tcu::DoubleFunc1& func)
2304 : m_name(name), m_func(func) {}
2305
getName(void) const2306 string getName (void) const { return m_name; }
2307
2308 protected:
applyExact(double x) const2309 double applyExact (double x) const { return m_func(x); }
2310
2311 const string m_name;
2312 tcu::DoubleFunc1& m_func;
2313 };
2314
2315 //<Signature<float, deFloat16, deFloat16> >
2316 //<Signature<float, float, float> >
2317 //<Signature<double, double, double> >
2318 template <class T>
2319 class FloatFunc2 : public PrimitiveFunc<T>
2320 {
2321 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2322 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2323 {
2324 return this->applyMonotone(ctx, iargs.a, iargs.b);
2325 }
2326
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi) const2327 Interval applyMonotone (const EvalContext& ctx,
2328 const Interval& xi,
2329 const Interval& yi) const
2330 {
2331 Interval reti;
2332
2333 TCU_INTERVAL_APPLY_MONOTONE2(reti, x, xi, y, yi, ret,
2334 TCU_SET_INTERVAL(ret, point,
2335 point = this->applyPoint(ctx, x, y)));
2336 reti |= innerExtrema(ctx, xi, yi);
2337 reti &= (this->getCodomain(ctx) | TCU_NAN);
2338
2339 return ctx.format.convert(reti);
2340 }
2341
innerExtrema(const EvalContext &,const Interval &,const Interval &) const2342 virtual Interval innerExtrema (const EvalContext&,
2343 const Interval&,
2344 const Interval&) const
2345 {
2346 return Interval(); // empty interval, i.e. no extrema
2347 }
2348
applyPoint(const EvalContext & ctx,double x,double y) const2349 virtual Interval applyPoint (const EvalContext& ctx,
2350 double x,
2351 double y) const
2352 {
2353 const double exact = this->applyExact(x, y);
2354 const double prec = this->precision(ctx, exact, x, y);
2355
2356 return exact + Interval(-prec, prec);
2357 }
2358
applyExact(double,double) const2359 virtual double applyExact (double, double) const
2360 {
2361 TCU_THROW(InternalError, "Cannot apply");
2362 }
2363
getCodomain(const EvalContext &) const2364 virtual Interval getCodomain (const EvalContext&) const
2365 {
2366 return Interval::unbounded(true);
2367 }
2368
2369 virtual double precision (const EvalContext& ctx,
2370 double ret,
2371 double x,
2372 double y) const = 0;
2373 };
2374
2375 template <class T>
2376 class CFloatFunc2 : public FloatFunc2<T>
2377 {
2378 public:
CFloatFunc2(const string & name,tcu::DoubleFunc2 & func)2379 CFloatFunc2 (const string& name,
2380 tcu::DoubleFunc2& func)
2381 : m_name(name)
2382 , m_func(func)
2383 {
2384 }
2385
getName(void) const2386 string getName (void) const { return m_name; }
2387
2388 protected:
applyExact(double x,double y) const2389 double applyExact (double x, double y) const { return m_func(x, y); }
2390
2391 const string m_name;
2392 tcu::DoubleFunc2& m_func;
2393 };
2394
2395 template <class T>
2396 class InfixOperator : public FloatFunc2<T>
2397 {
2398 protected:
2399 virtual string getSymbol (void) const = 0;
2400
doPrint(ostream & os,const BaseArgExprs & args) const2401 void doPrint (ostream& os, const BaseArgExprs& args) const
2402 {
2403 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2404 }
2405
applyPoint(const EvalContext & ctx,double x,double y) const2406 Interval applyPoint (const EvalContext& ctx,
2407 double x,
2408 double y) const
2409 {
2410 const double exact = this->applyExact(x, y);
2411
2412 // Allow either representable number on both sides of the exact value,
2413 // but require exactly representable values to be preserved.
2414 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2415 }
2416
precision(const EvalContext &,double,double,double) const2417 double precision (const EvalContext&, double, double, double) const
2418 {
2419 return 0.0;
2420 }
2421 };
2422
2423 class InfixOperator16Bit : public FloatFunc2 <Signature<float, deFloat16, deFloat16> >
2424 {
2425 protected:
2426 virtual string getSymbol (void) const = 0;
2427
doPrint(ostream & os,const BaseArgExprs & args) const2428 void doPrint (ostream& os, const BaseArgExprs& args) const
2429 {
2430 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2431 }
2432
applyPoint(const EvalContext & ctx,double x,double y) const2433 Interval applyPoint (const EvalContext& ctx,
2434 double x,
2435 double y) const
2436 {
2437 const double exact = this->applyExact(x, y);
2438
2439 // Allow either representable number on both sides of the exact value,
2440 // but require exactly representable values to be preserved.
2441 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2442 }
2443
precision(const EvalContext &,double,double,double) const2444 double precision (const EvalContext&, double, double, double) const
2445 {
2446 return 0.0;
2447 }
2448 };
2449
2450 template <class T>
2451 class FloatFunc3 : public PrimitiveFunc<T>
2452 {
2453 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1,typename T::Arg2>::IArgs & iargs) const2454 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1, typename T::Arg2>::IArgs& iargs) const
2455 {
2456 return this->applyMonotone(ctx, iargs.a, iargs.b, iargs.c);
2457 }
2458
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi,const Interval & zi) const2459 Interval applyMonotone (const EvalContext& ctx,
2460 const Interval& xi,
2461 const Interval& yi,
2462 const Interval& zi) const
2463 {
2464 Interval reti;
2465 TCU_INTERVAL_APPLY_MONOTONE3(reti, x, xi, y, yi, z, zi, ret,
2466 TCU_SET_INTERVAL(ret, point,
2467 point = this->applyPoint(ctx, x, y, z)));
2468 return ctx.format.convert(reti);
2469 }
2470
applyPoint(const EvalContext & ctx,double x,double y,double z) const2471 virtual Interval applyPoint (const EvalContext& ctx,
2472 double x,
2473 double y,
2474 double z) const
2475 {
2476 const double exact = this->applyExact(x, y, z);
2477 const double prec = this->precision(ctx, exact, x, y, z);
2478 return exact + Interval(-prec, prec);
2479 }
2480
applyExact(double,double,double) const2481 virtual double applyExact (double, double, double) const
2482 {
2483 TCU_THROW(InternalError, "Cannot apply");
2484 }
2485
2486 virtual double precision (const EvalContext& ctx,
2487 double result,
2488 double x,
2489 double y,
2490 double z) const = 0;
2491 };
2492
2493 // We define syntactic sugar functions for expression constructors. Since
2494 // these have the same names as ordinary mathematical operations (sin, log
2495 // etc.), it's better to give them a dedicated namespace.
2496 namespace Functions
2497 {
2498
2499 using namespace tcu;
2500
2501 template <class T>
2502 class Comparison : public InfixOperator < T >
2503 {
2504 public:
getName(void) const2505 string getName (void) const { return "comparison"; }
getSymbol(void) const2506 string getSymbol (void) const { return ""; }
2507
getSpirvCase() const2508 SpirVCaseT getSpirvCase () const { return SPIRV_CASETYPE_COMPARE; }
2509
doApply(const EvalContext & ctx,const typename Comparison<T>::IArgs & iargs) const2510 Interval doApply (const EvalContext& ctx,
2511 const typename Comparison<T>::IArgs& iargs) const
2512 {
2513 DE_UNREF(ctx);
2514 if (iargs.a.hasNaN() || iargs.b.hasNaN())
2515 {
2516 return TCU_NAN; // one of the floats is NaN: block analysis
2517 }
2518
2519 int operationFlag = 1;
2520 int result = 0;
2521 const double a = iargs.a.midpoint();
2522 const double b = iargs.b.midpoint();
2523
2524 for (int i = 0; i<2; ++i)
2525 {
2526 if (a == b)
2527 result += operationFlag;
2528 operationFlag = operationFlag << 1;
2529
2530 if (a > b)
2531 result += operationFlag;
2532 operationFlag = operationFlag << 1;
2533
2534 if (a < b)
2535 result += operationFlag;
2536 operationFlag = operationFlag << 1;
2537
2538 if (a >= b)
2539 result += operationFlag;
2540 operationFlag = operationFlag << 1;
2541
2542 if (a <= b)
2543 result += operationFlag;
2544 operationFlag = operationFlag << 1;
2545 }
2546 return result;
2547 }
2548 };
2549
2550 template <class T>
2551 class Add : public InfixOperator < T >
2552 {
2553 public:
getName(void) const2554 string getName (void) const { return "add"; }
getSymbol(void) const2555 string getSymbol (void) const { return "+"; }
2556
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2557 Interval doApply (const EvalContext& ctx,
2558 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2559 {
2560 // Fast-path for common case
2561 if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2562 {
2563 Interval ret;
2564 TCU_SET_INTERVAL_BOUNDS(ret, sum,
2565 sum = iargs.a.lo() + iargs.b.lo(),
2566 sum = iargs.a.hi() + iargs.b.hi());
2567 return ctx.format.convert(ctx.format.roundOut(ret, true));
2568 }
2569 return this->applyMonotone(ctx, iargs.a, iargs.b);
2570 }
2571
2572 protected:
applyExact(double x,double y) const2573 double applyExact (double x, double y) const { return x + y; }
2574 };
2575
2576 template<class T>
2577 class Mul : public InfixOperator<T>
2578 {
2579 public:
getName(void) const2580 string getName (void) const { return "mul"; }
getSymbol(void) const2581 string getSymbol (void) const { return "*"; }
2582
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2583 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2584 {
2585 Interval a = iargs.a;
2586 Interval b = iargs.b;
2587
2588 // Fast-path for common case
2589 if (a.isOrdinary(ctx.format.getMaxValue()) && b.isOrdinary(ctx.format.getMaxValue()))
2590 {
2591 Interval ret;
2592 if (a.hi() < 0)
2593 {
2594 a = -a;
2595 b = -b;
2596 }
2597 if (a.lo() >= 0 && b.lo() >= 0)
2598 {
2599 TCU_SET_INTERVAL_BOUNDS(ret, prod,
2600 prod = a.lo() * b.lo(),
2601 prod = a.hi() * b.hi());
2602 return ctx.format.convert(ctx.format.roundOut(ret, true));
2603 }
2604 if (a.lo() >= 0 && b.hi() <= 0)
2605 {
2606 TCU_SET_INTERVAL_BOUNDS(ret, prod,
2607 prod = a.hi() * b.lo(),
2608 prod = a.lo() * b.hi());
2609 return ctx.format.convert(ctx.format.roundOut(ret, true));
2610 }
2611 }
2612 return this->applyMonotone(ctx, iargs.a, iargs.b);
2613 }
2614
2615 protected:
applyExact(double x,double y) const2616 double applyExact (double x, double y) const { return x * y; }
2617
innerExtrema(const EvalContext &,const Interval & xi,const Interval & yi) const2618 Interval innerExtrema(const EvalContext&, const Interval& xi, const Interval& yi) const
2619 {
2620 if (((xi.contains(-TCU_INFINITY) || xi.contains(TCU_INFINITY)) && yi.contains(0.0)) ||
2621 ((yi.contains(-TCU_INFINITY) || yi.contains(TCU_INFINITY)) && xi.contains(0.0)))
2622 return Interval(TCU_NAN);
2623
2624 return Interval();
2625 }
2626 };
2627
2628 template<class T>
2629 class Sub : public InfixOperator <T>
2630 {
2631 public:
getName(void) const2632 string getName (void) const { return "sub"; }
getSymbol(void) const2633 string getSymbol (void) const { return "-"; }
2634
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2635 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2636 {
2637 // Fast-path for common case
2638 if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2639 {
2640 Interval ret;
2641
2642 TCU_SET_INTERVAL_BOUNDS(ret, diff,
2643 diff = iargs.a.lo() - iargs.b.hi(),
2644 diff = iargs.a.hi() - iargs.b.lo());
2645 return ctx.format.convert(ctx.format.roundOut(ret, true));
2646
2647 }
2648 else
2649 {
2650 return this->applyMonotone(ctx, iargs.a, iargs.b);
2651 }
2652 }
2653
2654 protected:
applyExact(double x,double y) const2655 double applyExact (double x, double y) const { return x - y; }
2656 };
2657
2658 template <class T>
2659 class Negate : public FloatFunc1<T>
2660 {
2661 public:
getName(void) const2662 string getName (void) const { return "_negate"; }
doPrint(ostream & os,const BaseArgExprs & args) const2663 void doPrint (ostream& os, const BaseArgExprs& args) const { os << "-" << *args[0]; }
2664
2665 protected:
precision(const EvalContext &,double,double) const2666 double precision (const EvalContext&, double, double) const { return 0.0; }
applyExact(double x) const2667 double applyExact (double x) const { return -x; }
2668 };
2669
2670 template <class T>
2671 class Div : public InfixOperator<T>
2672 {
2673 public:
getName(void) const2674 string getName (void) const { return "div"; }
2675
2676 protected:
getSymbol(void) const2677 string getSymbol (void) const { return "/"; }
2678
innerExtrema(const EvalContext &,const Interval & nom,const Interval & den) const2679 Interval innerExtrema (const EvalContext&,
2680 const Interval& nom,
2681 const Interval& den) const
2682 {
2683 Interval ret;
2684
2685 if (den.contains(0.0))
2686 {
2687 if (nom.contains(0.0))
2688 ret |= TCU_NAN;
2689
2690 if (nom.lo() < 0.0 || nom.hi() > 0.0)
2691 ret |= Interval::unbounded();
2692 }
2693
2694 return ret;
2695 }
2696
applyExact(double x,double y) const2697 double applyExact (double x, double y) const { return x / y; }
2698
applyPoint(const EvalContext & ctx,double x,double y) const2699 Interval applyPoint (const EvalContext& ctx, double x, double y) const
2700 {
2701 Interval ret = FloatFunc2<T>::applyPoint(ctx, x, y);
2702
2703 if (!deIsInf(x) && !deIsInf(y) && y != 0.0)
2704 {
2705 const Interval dst = ctx.format.convert(ret);
2706 if (dst.contains(-TCU_INFINITY)) ret |= -ctx.format.getMaxValue();
2707 if (dst.contains(+TCU_INFINITY)) ret |= +ctx.format.getMaxValue();
2708 }
2709
2710 return ret;
2711 }
2712
precision(const EvalContext & ctx,double ret,double,double den) const2713 double precision (const EvalContext& ctx, double ret, double, double den) const
2714 {
2715 const FloatFormat& fmt = ctx.format;
2716
2717 // \todo [2014-03-05 lauri] Check that the limits in GLSL 3.10 are actually correct.
2718 // For now, we assume that division's precision is 2.5 ULP when the value is within
2719 // [2^MINEXP, 2^MAXEXP-1]
2720
2721 if (den == 0.0)
2722 return 0.0; // Result must be exactly inf
2723 else if (de::inBounds(deAbs(den),
2724 deLdExp(1.0, fmt.getMinExp()),
2725 deLdExp(1.0, fmt.getMaxExp() - 1)))
2726 return fmt.ulp(ret, 2.5);
2727 else
2728 return TCU_INFINITY; // Can be any number, but must be a number.
2729 }
2730 };
2731
2732 template <class T>
2733 class InverseSqrt : public FloatFunc1 <T>
2734 {
2735 public:
getName(void) const2736 string getName (void) const { return "inversesqrt"; }
2737
2738 protected:
applyExact(double x) const2739 double applyExact (double x) const { return 1.0 / deSqrt(x); }
2740
precision(const EvalContext & ctx,double ret,double x) const2741 double precision (const EvalContext& ctx, double ret, double x) const
2742 {
2743 return x <= 0 ? TCU_NAN : ctx.format.ulp(ret, 2.0);
2744 }
2745
getCodomain(const EvalContext &) const2746 Interval getCodomain (const EvalContext&) const
2747 {
2748 return Interval(0.0, TCU_INFINITY);
2749 }
2750 };
2751
2752 template <class T>
2753 class ExpFunc : public CFloatFunc1<T>
2754 {
2755 public:
ExpFunc(const string & name,DoubleFunc1 & func)2756 ExpFunc (const string& name, DoubleFunc1& func)
2757 : CFloatFunc1<T> (name, func)
2758 {}
2759 protected:
2760 double precision (const EvalContext& ctx, double ret, double x) const;
getCodomain(const EvalContext &) const2761 Interval getCodomain (const EvalContext&) const
2762 {
2763 return Interval(0.0, TCU_INFINITY);
2764 }
2765 };
2766
2767 template <>
precision(const EvalContext & ctx,double ret,double x) const2768 double ExpFunc <Signature<float, float> >::precision (const EvalContext& ctx, double ret, double x) const
2769 {
2770 switch (ctx.floatPrecision)
2771 {
2772 case glu::PRECISION_HIGHP:
2773 return ctx.format.ulp(ret, 3.0 + 2.0 * deAbs(x));
2774 case glu::PRECISION_MEDIUMP:
2775 case glu::PRECISION_LAST:
2776 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2777 default:
2778 DE_FATAL("Impossible");
2779 }
2780
2781 return 0.0;
2782 }
2783
2784 template <>
precision(const EvalContext & ctx,double ret,double x) const2785 double ExpFunc <Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double x) const
2786 {
2787 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2788 }
2789
2790 template <>
precision(const EvalContext & ctx,double ret,double x) const2791 double ExpFunc <Signature<double, double> >::precision(const EvalContext& ctx, double ret, double x) const
2792 {
2793 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2794 }
2795
2796 template <class T>
Exp2(void)2797 class Exp2 : public ExpFunc<T> { public: Exp2 (void) : ExpFunc<T>("exp2", deExp2) {} };
2798 template <class T>
Exp(void)2799 class Exp : public ExpFunc<T> { public: Exp (void) : ExpFunc<T>("exp", deExp) {} };
2800
2801 template <typename T>
exp2(const ExprP<T> & x)2802 ExprP<T> exp2 (const ExprP<T>& x) { return app<Exp2< Signature<T, T> > >(x); }
2803 template <typename T>
exp(const ExprP<T> & x)2804 ExprP<T> exp (const ExprP<T>& x) { return app<Exp< Signature<T, T> > >(x); }
2805
2806 template <class T>
2807 class LogFunc : public CFloatFunc1<T>
2808 {
2809 public:
LogFunc(const string & name,DoubleFunc1 & func)2810 LogFunc (const string& name, DoubleFunc1& func)
2811 : CFloatFunc1<T>(name, func) {}
2812
2813 protected:
2814 double precision (const EvalContext& ctx, double ret, double x) const;
2815 };
2816
2817 template <>
precision(const EvalContext & ctx,double ret,double x) const2818 double LogFunc<Signature<float, float> >::precision(const EvalContext& ctx, double ret, double x) const
2819 {
2820 if (x <= 0)
2821 return TCU_NAN;
2822
2823 switch (ctx.floatPrecision)
2824 {
2825 case glu::PRECISION_HIGHP:
2826 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
2827 case glu::PRECISION_MEDIUMP:
2828 case glu::PRECISION_LAST:
2829 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
2830 default:
2831 DE_FATAL("Impossible");
2832 }
2833
2834 return 0;
2835 }
2836
2837 template <>
precision(const EvalContext & ctx,double ret,double x) const2838 double LogFunc<Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double x) const
2839 {
2840 if (x <= 0)
2841 return TCU_NAN;
2842 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
2843 }
2844
2845 // Spec: "The precision of double-precision instructions is at least that of single precision."
2846 // Lets pick float high precision as a reference.
2847 template <>
precision(const EvalContext & ctx,double ret,double x) const2848 double LogFunc<Signature<double, double> >::precision(const EvalContext& ctx, double ret, double x) const
2849 {
2850 if (x <= 0)
2851 return TCU_NAN;
2852 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
2853 }
2854
2855 template <class T>
Log2(void)2856 class Log2 : public LogFunc<T> { public: Log2 (void) : LogFunc<T>("log2", deLog2) {} };
2857 template <class T>
Log(void)2858 class Log : public LogFunc<T> { public: Log (void) : LogFunc<T>("log", deLog) {} };
2859
// Expression constructors for log2/log, one overload per floating-point type.

// 32-bit
ExprP<float> log2 (const ExprP<float>& x)
{
	return app<Log2< Signature<float, float> > >(x);
}

ExprP<float> log (const ExprP<float>& x)
{
	return app<Log< Signature<float, float> > >(x);
}

// 16-bit
ExprP<deFloat16> log2 (const ExprP<deFloat16>& x)
{
	return app<Log2< Signature<deFloat16, deFloat16> > >(x);
}

ExprP<deFloat16> log (const ExprP<deFloat16>& x)
{
	return app<Log< Signature<deFloat16, deFloat16> > >(x);
}

// 64-bit
ExprP<double> log2 (const ExprP<double>& x)
{
	return app<Log2< Signature<double, double> > >(x);
}

ExprP<double> log (const ExprP<double>& x)
{
	return app<Log< Signature<double, double> > >(x);
}
2868
// Defines a free function NAME taking one ExprP<T0> and returning the
// ExprP<TRET> produced by app<CLASS>() on it.
#define DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)	\
ExprP<TRET> NAME (const ExprP<T0>& operand)			\
{													\
	return app<CLASS>(operand);						\
}
2871
// Defines a one-argument DerivedFunc subclass CLASS whose getName() returns
// the stringified NAME and whose doExpand() binds the argument to the
// identifier ARG0 and returns EXPANSION. Also emits the matching free
// function through DEFINE_CONSTRUCTOR1.
#define DEFINE_DERIVED1(CLASS, TRET, NAME, T0, ARG0, EXPANSION)				\
class CLASS : public DerivedFunc<Signature<TRET, T0> > /* NOLINT(CLASS) */	\
{																			\
public:																		\
	string			getName		(void) const								\
	{																		\
		return #NAME;														\
	}																		\
																			\
protected:																	\
	ExprP<TRET>		doExpand	(ExpandContext&,							\
								 const CLASS::ArgExprs& args_) const		\
	{																		\
		const ExprP<T0>& ARG0 = args_.a;									\
		return EXPANSION;													\
	}																		\
};																			\
DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)
2887
// DEFINE_DERIVED1 with double argument and return types.
#define DEFINE_DERIVED_DOUBLE1(CLASS, NAME, ARG0, EXPANSION)	\
	DEFINE_DERIVED1(CLASS, double, NAME, double, ARG0, EXPANSION)

// DEFINE_DERIVED1 with float argument and return types.
#define DEFINE_DERIVED_FLOAT1(CLASS, NAME, ARG0, EXPANSION)		\
	DEFINE_DERIVED1(CLASS, float, NAME, float, ARG0, EXPANSION)
2893
2894
// Same as DEFINE_DERIVED1, but the generated class additionally defines
// getInputRange(), which ignores its bool argument and returns INTERVAL.
#define DEFINE_DERIVED1_INPUTRANGE(CLASS, TRET, NAME, T0, ARG0, EXPANSION, INTERVAL)	\
class CLASS : public DerivedFunc<Signature<TRET, T0> > /* NOLINT(CLASS) */				\
{																						\
public:																					\
	string			getName			(void) const										\
	{																					\
		return #NAME;																	\
	}																					\
																						\
protected:																				\
	ExprP<TRET>		doExpand		(ExpandContext&,									\
									 const CLASS::ArgExprs& args_) const				\
	{																					\
		const ExprP<T0>& ARG0 = args_.a;												\
		return EXPANSION;																\
	}																					\
																						\
	Interval		getInputRange	(const bool /*is16bit*/) const						\
	{																					\
		return INTERVAL;																\
	}																					\
};																						\
DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)
2914
// DEFINE_DERIVED1_INPUTRANGE with float argument and return types.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL)		\
	DEFINE_DERIVED1_INPUTRANGE(CLASS, float, NAME, float, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1_INPUTRANGE with double argument and return types.
#define DEFINE_DERIVED_DOUBLE1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL)		\
	DEFINE_DERIVED1_INPUTRANGE(CLASS, double, NAME, double, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1 with deFloat16 argument and return types.
#define DEFINE_DERIVED_FLOAT1_16BIT(CLASS, NAME, ARG0, EXPANSION)						\
	DEFINE_DERIVED1(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION)

// DEFINE_DERIVED1_INPUTRANGE with deFloat16 argument and return types.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(CLASS, NAME, ARG0, EXPANSION, INTERVAL)	\
	DEFINE_DERIVED1_INPUTRANGE(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION, INTERVAL)
2926
// Defines a free function NAME taking two expressions and returning the
// ExprP<TRET> produced by app<CLASS>() on them, in the same order.
#define DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)					\
ExprP<TRET> NAME (const ExprP<T0>& first, const ExprP<T1>& second)		\
{																		\
	return app<CLASS>(first, second);									\
}
2932
// Defines a two-argument DerivedFunc subclass CLASS. getName() returns the
// stringified NAME, getSpirvCase() returns SPIRVCASE, and doExpand() binds the
// arguments to the identifiers Arg0 and Arg1 before returning EXPANSION. Also
// emits the matching free function through DEFINE_CONSTRUCTOR2.
#define DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRVCASE)	\
class CLASS : public DerivedFunc<Signature<TRET, T0, T1> > /* NOLINT(CLASS) */				\
{																							\
public:																						\
	string			getName			(void) const											\
	{																						\
		return #NAME;																		\
	}																						\
																							\
	SpirVCaseT		getSpirvCase	(void) const											\
	{																						\
		return SPIRVCASE;																	\
	}																						\
																							\
protected:																					\
	ExprP<TRET>		doExpand		(ExpandContext&, const ArgExprs& args_) const			\
	{																						\
		const ExprP<T0>& Arg0 = args_.a;													\
		const ExprP<T1>& Arg1 = args_.b;													\
		return EXPANSION;																	\
	}																						\
};																							\
DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)
2950
// DEFINE_CASED_DERIVED2 with the SPIR-V case fixed to SPIRV_CASETYPE_NONE.
#define DEFINE_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION)					\
	DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRV_CASETYPE_NONE)

// DEFINE_DERIVED2 with all types set to double.
#define DEFINE_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION)							\
	DEFINE_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION)

// DEFINE_DERIVED2 with all types set to float.
#define DEFINE_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION)							\
	DEFINE_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION)

// DEFINE_DERIVED2 with all types set to deFloat16.
#define DEFINE_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION)						\
	DEFINE_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION)

// DEFINE_CASED_DERIVED2 with all types set to float.
#define DEFINE_CASED_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE)			\
	DEFINE_CASED_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION, SPIRVCASE)

// DEFINE_CASED_DERIVED2 with all types set to deFloat16.
#define DEFINE_CASED_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE)	\
	DEFINE_CASED_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION, SPIRVCASE)

// DEFINE_CASED_DERIVED2 with all types set to double.
#define DEFINE_CASED_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE)			\
	DEFINE_CASED_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION, SPIRVCASE)
2971
// Defines a free function NAME taking three expressions and returning the
// ExprP<TRET> produced by app<CLASS>() on them, in the same order.
#define DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)		\
ExprP<TRET> NAME (const ExprP<T0>& first,						\
				  const ExprP<T1>& second,						\
				  const ExprP<T2>& third)						\
{																\
	return app<CLASS>(first, second, third);					\
}
2977
// Defines a three-argument DerivedFunc subclass CLASS whose getName() returns
// the stringified NAME and whose doExpand() binds the arguments to the
// identifiers ARG0, ARG1 and ARG2 before returning EXPANSION. Also emits the
// matching free function through DEFINE_CONSTRUCTOR3.
#define DEFINE_DERIVED3(CLASS, TRET, NAME, T0, ARG0, T1, ARG1, T2, ARG2, EXPANSION)	\
class CLASS : public DerivedFunc<Signature<TRET, T0, T1, T2> > /* NOLINT(CLASS) */	\
{																					\
public:																				\
	string			getName		(void) const										\
	{																				\
		return #NAME;																\
	}																				\
																					\
protected:																			\
	ExprP<TRET>		doExpand	(ExpandContext&, const ArgExprs& args_) const		\
	{																				\
		const ExprP<T0>& ARG0 = args_.a;											\
		const ExprP<T1>& ARG1 = args_.b;											\
		const ExprP<T2>& ARG2 = args_.c;											\
		return EXPANSION;															\
	}																				\
};																					\
DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)
2994
// DEFINE_DERIVED3 with all types set to double.
#define DEFINE_DERIVED_DOUBLE3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION)		\
	DEFINE_DERIVED3(CLASS, double, NAME, double, ARG0, double, ARG1, double, ARG2, EXPANSION)

// DEFINE_DERIVED3 with all types set to float.
#define DEFINE_DERIVED_FLOAT3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION)			\
	DEFINE_DERIVED3(CLASS, float, NAME, float, ARG0, float, ARG1, float, ARG2, EXPANSION)

// DEFINE_DERIVED3 with all types set to deFloat16.
#define DEFINE_DERIVED_FLOAT3_16BIT(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION)	\
	DEFINE_DERIVED3(CLASS, deFloat16, NAME, deFloat16, ARG0, deFloat16, ARG1, deFloat16, ARG2, EXPANSION)
3003
// Defines a free function NAME taking four expressions and returning the
// ExprP<TRET> produced by app<CLASS>() on them, in the same order.
#define DEFINE_CONSTRUCTOR4(CLASS, TRET, NAME, T0, T1, T2, T3)	\
ExprP<TRET> NAME (const ExprP<T0>& first,						\
				  const ExprP<T1>& second,						\
				  const ExprP<T2>& third,						\
				  const ExprP<T3>& fourth)						\
{																\
	return app<CLASS>(first, second, third, fourth);			\
}
3010
3011 typedef InverseSqrt< Signature<deFloat16, deFloat16> > InverseSqrt16Bit;
3012 typedef InverseSqrt< Signature<float, float> > InverseSqrt32Bit;
3013 typedef InverseSqrt< Signature<double, double> > InverseSqrt64Bit;
3014
3015 DEFINE_DERIVED_FLOAT1(Sqrt32Bit, sqrt, x, constant(1.0f) / app<InverseSqrt32Bit>(x))
3016 DEFINE_DERIVED_FLOAT1_16BIT(Sqrt16Bit, sqrt, x, constant((deFloat16)FLOAT16_1_0) / app<InverseSqrt16Bit>(x))
3017 DEFINE_DERIVED_DOUBLE1(Sqrt64Bit, sqrt, x, constant(1.0) / app<InverseSqrt64Bit>(x))
3018 DEFINE_DERIVED_FLOAT2(Pow, pow, x, y, exp2<float>(y * log2(x)))
3019 DEFINE_DERIVED_FLOAT2_16BIT(Pow16, pow, x, y, exp2<deFloat16>(y * log2(x)))
3020 DEFINE_DERIVED_DOUBLE2(Pow64, pow, x, y, exp2<double>(y * log2(x)))
3021 DEFINE_DERIVED_FLOAT1(Radians, radians, d, (constant(DE_PI) / constant(180.0f)) * d)
3022 DEFINE_DERIVED_FLOAT1_16BIT(Radians16, radians, d, (constant((deFloat16)DE_PI_16BIT) / constant((deFloat16)FLOAT16_180_0)) * d)
3023 DEFINE_DERIVED_DOUBLE1(Radians64, radians, d, (constant((double)(DE_PI)) / constant(180.0)) * d)
3024 DEFINE_DERIVED_FLOAT1(Degrees, degrees, r, (constant(180.0f) / constant(DE_PI)) * r)
3025 DEFINE_DERIVED_FLOAT1_16BIT(Degrees16, degrees, r, (constant((deFloat16)FLOAT16_180_0) / constant((deFloat16)DE_PI_16BIT)) * r)
3026 DEFINE_DERIVED_DOUBLE1(Degrees64, degrees, r, (constant(180.0) / constant((double)(DE_PI))) * r)
3027
3028 /*Proper parameters for template T
3029 Signature<float, float> 32bit tests
3030 Signature<float, deFloat16> 16bit tests*/
3031 template<class T>
3032 class TrigFunc : public CFloatFunc1<T>
3033 {
3034 public:
TrigFunc(const string & name,DoubleFunc1 & func,const Interval & loEx,const Interval & hiEx)3035 TrigFunc (const string& name,
3036 DoubleFunc1& func,
3037 const Interval& loEx,
3038 const Interval& hiEx)
3039 : CFloatFunc1<T> (name, func)
3040 , m_loExtremum (loEx)
3041 , m_hiExtremum (hiEx) {}
3042
3043 protected:
innerExtrema(const EvalContext &,const Interval & angle) const3044 Interval innerExtrema (const EvalContext&, const Interval& angle) const
3045 {
3046 const double lo = angle.lo();
3047 const double hi = angle.hi();
3048 const int loSlope = doGetSlope(lo);
3049 const int hiSlope = doGetSlope(hi);
3050
3051 // Detect the high and low values the function can take between the
3052 // interval endpoints.
3053 if (angle.length() >= 2.0 * DE_PI_DOUBLE)
3054 {
3055 // The interval is longer than a full cycle, so it must get all possible values.
3056 return m_hiExtremum | m_loExtremum;
3057 }
3058 else if (loSlope == 1 && hiSlope == -1)
3059 {
3060 // The slope can change from positive to negative only at the maximum value.
3061 return m_hiExtremum;
3062 }
3063 else if (loSlope == -1 && hiSlope == 1)
3064 {
3065 // The slope can change from negative to positive only at the maximum value.
3066 return m_loExtremum;
3067 }
3068 else if (loSlope == hiSlope &&
3069 deIntSign(CFloatFunc1<T>::applyExact(hi) - CFloatFunc1<T>::applyExact(lo)) * loSlope == -1)
3070 {
3071 // The slope has changed twice between the endpoints, so both extrema are included.
3072 return m_hiExtremum | m_loExtremum;
3073 }
3074
3075 return Interval();
3076 }
3077
getCodomain(const EvalContext &) const3078 Interval getCodomain (const EvalContext&) const
3079 {
3080 // Ensure that result is always within [-1, 1], or NaN (for +-inf)
3081 return Interval(-1.0, 1.0) | TCU_NAN;
3082 }
3083
3084 double precision (const EvalContext& ctx, double ret, double arg) const;
3085
3086 Interval getInputRange (const bool is16bit) const;
3087 virtual int doGetSlope (double angle) const = 0;
3088
3089 Interval m_loExtremum;
3090 Interval m_hiExtremum;
3091 };
3092
3093 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3094 template<>
getInputRange(const bool is16bit) const3095 Interval TrigFunc<Signature<float, float> >::getInputRange(const bool is16bit) const
3096 {
3097 DE_UNREF(is16bit);
3098 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3099 }
3100
3101 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3102 template<>
getInputRange(const bool is16bit) const3103 Interval TrigFunc<Signature<deFloat16, deFloat16> >::getInputRange(const bool is16bit) const
3104 {
3105 DE_UNREF(is16bit);
3106 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3107 }
3108
3109 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3110 template<>
getInputRange(const bool is16bit) const3111 Interval TrigFunc<Signature<double, double> >::getInputRange(const bool is16bit) const
3112 {
3113 DE_UNREF(is16bit);
3114 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3115 }
3116
3117 template<>
precision(const EvalContext & ctx,double ret,double arg) const3118 double TrigFunc<Signature<float, float> >::precision(const EvalContext& ctx, double ret, double arg) const
3119 {
3120 DE_UNREF(ret);
3121 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3122 {
3123 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3124 return deLdExp(1.0, -11);
3125 else
3126 {
3127 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3128 // 2^-11 at x == pi.
3129 return deLdExp(deAbs(arg), -12);
3130 }
3131 }
3132 else
3133 {
3134 DE_ASSERT(ctx.floatPrecision == glu::PRECISION_MEDIUMP || ctx.floatPrecision == glu::PRECISION_LAST);
3135
3136 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3137 return deLdExp(1.0, -7);
3138 else
3139 {
3140 // |x| * 2^-8, slightly larger than 2^-7 at x == pi
3141 return deLdExp(deAbs(arg), -8);
3142 }
3143 }
3144 }
3145 //
3146 /*
3147 * Half tests
3148 * From Spec:
3149 * Absolute error 2^{-7} inside the range [-pi, pi].
3150 */
3151 template<>
precision(const EvalContext & ctx,double ret,double arg) const3152 double TrigFunc<Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double arg) const
3153 {
3154 DE_UNREF(ctx);
3155 DE_UNREF(ret);
3156 DE_UNREF(arg);
3157 DE_ASSERT(-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE && ctx.floatPrecision == glu::PRECISION_LAST);
3158 return deLdExp(1.0, -7);
3159 }
3160
3161 // Spec: "The precision of double-precision instructions is at least that of single precision."
3162 // Lets pick float high precision as a reference.
3163 template<>
precision(const EvalContext & ctx,double ret,double arg) const3164 double TrigFunc<Signature<double, double> >::precision(const EvalContext& ctx, double ret, double arg) const
3165 {
3166 DE_UNREF(ctx);
3167 DE_UNREF(ret);
3168 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3169 return deLdExp(1.0, -11);
3170 else
3171 {
3172 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3173 // 2^-11 at x == pi.
3174 return deLdExp(deAbs(arg), -12);
3175 }
3176 }
3177
3178 /*Proper parameters for template T
3179 Signature<float, float> 32bit tests
3180 Signature<float, deFloat16> 16bit tests*/
3181 template <class T>
3182 class Sin : public TrigFunc<T>
3183 {
3184 public:
Sin(void)3185 Sin (void) : TrigFunc<T>("sin", deSin, -1.0, 1.0) {}
3186
3187 protected:
doGetSlope(double angle) const3188 int doGetSlope (double angle) const { return deIntSign(deCos(angle)); }
3189 };
3190
// Expression constructors for sin, one overload per floating-point type.
ExprP<float> sin (const ExprP<float>& x)
{
	return app<Sin<Signature<float, float> > >(x);
}

ExprP<deFloat16> sin (const ExprP<deFloat16>& x)
{
	return app<Sin<Signature<deFloat16, deFloat16> > >(x);
}

ExprP<double> sin (const ExprP<double>& x)
{
	return app<Sin<Signature<double, double> > >(x);
}
3194
3195 template <class T>
3196 class Cos : public TrigFunc<T>
3197 {
3198 public:
Cos(void)3199 Cos (void) : TrigFunc<T> ("cos", deCos, -1.0, 1.0) {}
3200
3201 protected:
doGetSlope(double angle) const3202 int doGetSlope (double angle) const { return -deIntSign(deSin(angle)); }
3203 };
3204
// Expression constructors for cos, one overload per floating-point type.
ExprP<float> cos (const ExprP<float>& x)
{
	return app<Cos<Signature<float, float> > >(x);
}

ExprP<deFloat16> cos (const ExprP<deFloat16>& x)
{
	return app<Cos<Signature<deFloat16, deFloat16> > >(x);
}

ExprP<double> cos (const ExprP<double>& x)
{
	return app<Cos<Signature<double, double> > >(x);
}
3208
3209 DEFINE_DERIVED_FLOAT1_INPUTRANGE(Tan, tan, x, sin(x) * (constant(1.0f) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
3210 DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(Tan16Bit, tan, x, sin(x) * (constant((deFloat16)FLOAT16_1_0) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
3211 DEFINE_DERIVED_DOUBLE1_INPUTRANGE(Tan64Bit, tan, x, sin(x) * (constant(1.0) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
3212
3213 template <class T>
3214 class ATan : public CFloatFunc1<T>
3215 {
3216 public:
ATan(void)3217 ATan (void) : CFloatFunc1<T> ("atan", deAtanOver) {}
3218
3219 protected:
precision(const EvalContext & ctx,double ret,double) const3220 double precision (const EvalContext& ctx, double ret, double) const
3221 {
3222 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3223 return ctx.format.ulp(ret, 4096.0);
3224 else
3225 return ctx.format.ulp(ret, 5.0);
3226 }
3227
getCodomain(const EvalContext & ctx) const3228 Interval getCodomain(const EvalContext& ctx) const
3229 {
3230 return ctx.format.roundOut(Interval(-0.5 * DE_PI_DOUBLE, 0.5 * DE_PI_DOUBLE), true);
3231 }
3232 };
3233
3234 template <class T>
3235 class ATan2 : public CFloatFunc2<T>
3236 {
3237 public:
ATan2(void)3238 ATan2 (void) : CFloatFunc2<T> ("atan", deAtan2) {}
3239
3240 protected:
innerExtrema(const EvalContext & ctx,const Interval & yi,const Interval & xi) const3241 Interval innerExtrema (const EvalContext& ctx,
3242 const Interval& yi,
3243 const Interval& xi) const
3244 {
3245 Interval ret;
3246
3247 if (yi.contains(0.0))
3248 {
3249 if (xi.contains(0.0))
3250 ret |= TCU_NAN;
3251 if (xi.intersects(Interval(-TCU_INFINITY, 0.0)))
3252 ret |= ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3253 }
3254
3255 if (!yi.isFinite(ctx.format.getMaxValue()) || !xi.isFinite(ctx.format.getMaxValue()))
3256 {
3257 // Infinities may not be supported, allow anything, including NaN
3258 ret |= TCU_NAN;
3259 }
3260
3261 return ret;
3262 }
3263
precision(const EvalContext & ctx,double ret,double,double) const3264 double precision (const EvalContext& ctx, double ret, double, double) const
3265 {
3266 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3267 return ctx.format.ulp(ret, 4096.0);
3268 else
3269 return ctx.format.ulp(ret, 5.0);
3270 }
3271
getCodomain(const EvalContext & ctx) const3272 Interval getCodomain(const EvalContext& ctx) const
3273 {
3274 return ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3275 }
3276 };
3277
// Expression constructors for the two-argument arctangent, one overload per
// floating-point type. Arguments are forwarded to ATan2 in the order given.
// NOTE(review): ATan2::innerExtrema names its intervals (yi, xi), so the first
// parameter here (named x) appears to be the y coordinate -- confirm.
ExprP<float> atan2 (const ExprP<float>& x, const ExprP<float>& y)
{
	return app<ATan2<Signature<float, float, float> > >(x, y);
}

ExprP<deFloat16> atan2 (const ExprP<deFloat16>& x, const ExprP<deFloat16>& y)
{
	return app<ATan2<Signature<deFloat16, deFloat16, deFloat16> > >(x, y);
}

ExprP<double> atan2 (const ExprP<double>& x, const ExprP<double>& y)
{
	return app<ATan2<Signature<double, double, double> > >(x, y);
}
3283
3284
3285 DEFINE_DERIVED_FLOAT1(Sinh, sinh, x, (exp<float>(x) - exp<float>(-x)) / constant(2.0f))
3286 DEFINE_DERIVED_FLOAT1(Cosh, cosh, x, (exp<float>(x) + exp<float>(-x)) / constant(2.0f))
3287 DEFINE_DERIVED_FLOAT1(Tanh, tanh, x, sinh(x) / cosh(x))
3288
3289 DEFINE_DERIVED_FLOAT1_16BIT(Sinh16Bit, sinh, x, (exp(x) - exp(-x)) / constant((deFloat16)FLOAT16_2_0))
3290 DEFINE_DERIVED_FLOAT1_16BIT(Cosh16Bit, cosh, x, (exp(x) + exp(-x)) / constant((deFloat16)FLOAT16_2_0))
3291 DEFINE_DERIVED_FLOAT1_16BIT(Tanh16Bit, tanh, x, sinh(x) / cosh(x))
3292
3293 DEFINE_DERIVED_DOUBLE1(Sinh64Bit, sinh, x, (exp<double>(x) - exp<double>(-x)) / constant(2.0))
3294 DEFINE_DERIVED_DOUBLE1(Cosh64Bit, cosh, x, (exp<double>(x) + exp<double>(-x)) / constant(2.0))
3295 DEFINE_DERIVED_DOUBLE1(Tanh64Bit, tanh, x, sinh(x) / cosh(x))
3296
3297 DEFINE_DERIVED_FLOAT1(ASin, asin, x, atan2(x, sqrt(constant(1.0f) - x * x)))
3298 DEFINE_DERIVED_FLOAT1(ACos, acos, x, atan2(sqrt(constant(1.0f) - x * x), x))
3299 DEFINE_DERIVED_FLOAT1(ASinh, asinh, x, log(x + sqrt(x * x + constant(1.0f))))
3300 DEFINE_DERIVED_FLOAT1(ACosh, acosh, x, log(x + sqrt(alternatives((x + constant(1.0f)) * (x - constant(1.0f)),
3301 (x * x - constant(1.0f))))))
3302 DEFINE_DERIVED_FLOAT1(ATanh, atanh, x, constant(0.5f) * log((constant(1.0f) + x) /
3303 (constant(1.0f) - x)))
3304
3305 DEFINE_DERIVED_FLOAT1_16BIT(ASin16Bit, asin, x, atan2(x, sqrt(constant((deFloat16)FLOAT16_1_0) - x * x)))
3306 DEFINE_DERIVED_FLOAT1_16BIT(ACos16Bit, acos, x, atan2(sqrt(constant((deFloat16)FLOAT16_1_0) - x * x), x))
3307 DEFINE_DERIVED_FLOAT1_16BIT(ASinh16Bit, asinh, x, log(x + sqrt(x * x + constant((deFloat16)FLOAT16_1_0))))
3308 DEFINE_DERIVED_FLOAT1_16BIT(ACosh16Bit, acosh, x, log(x + sqrt(alternatives((x + constant((deFloat16)FLOAT16_1_0)) * (x - constant((deFloat16)FLOAT16_1_0)),
3309 (x * x - constant((deFloat16)FLOAT16_1_0))))))
3310 DEFINE_DERIVED_FLOAT1_16BIT(ATanh16Bit, atanh, x, constant((deFloat16)FLOAT16_0_5) * log((constant((deFloat16)FLOAT16_1_0) + x) /
3311 (constant((deFloat16)FLOAT16_1_0) - x)))
3312
3313 DEFINE_DERIVED_DOUBLE1(ASin64Bit, asin, x, atan2(x, sqrt(constant(1.0) - pow(x, constant(2.0)))))
3314 DEFINE_DERIVED_DOUBLE1(ACos64Bit, acos, x, atan2(sqrt(constant(1.0) - pow(x, constant(2.0))), x))
3315 DEFINE_DERIVED_DOUBLE1(ASinh64Bit, asinh, x, log(x + sqrt(x * x + constant(1.0))))
3316 DEFINE_DERIVED_DOUBLE1(ACosh64Bit, acosh, x, log(x + sqrt(alternatives((x + constant(1.0)) * (x - constant(1.0)),
3317 (x * x - constant(1.0))))))
3318 DEFINE_DERIVED_DOUBLE1(ATanh64Bit, atanh, x, constant(0.5) * log((constant(1.0) + x) /
3319 (constant(1.0) - x)))
3320
3321 template <typename T>
3322 class GetComponent : public PrimitiveFunc<Signature<typename T::Element, T, int> >
3323 {
3324 public:
3325 typedef typename GetComponent::IRet IRet;
3326
getName(void) const3327 string getName (void) const { return "_getComponent"; }
3328
print(ostream & os,const BaseArgExprs & args) const3329 void print (ostream& os,
3330 const BaseArgExprs& args) const
3331 {
3332 os << *args[0] << "[" << *args[1] << "]";
3333 }
3334
3335 protected:
doApply(const EvalContext &,const typename GetComponent::IArgs & iargs) const3336 IRet doApply (const EvalContext&,
3337 const typename GetComponent::IArgs& iargs) const
3338 {
3339 IRet ret;
3340
3341 for (int compNdx = 0; compNdx < T::SIZE; ++compNdx)
3342 {
3343 if (iargs.b.contains(compNdx))
3344 ret = unionIVal<typename T::Element>(ret, iargs.a[compNdx]);
3345 }
3346
3347 return ret;
3348 }
3349
3350 };
3351
3352 template <typename T>
getComponent(const ExprP<T> & container,int ndx)3353 ExprP<typename T::Element> getComponent (const ExprP<T>& container, int ndx)
3354 {
3355 DE_ASSERT(0 <= ndx && ndx < T::SIZE);
3356 return app<GetComponent<T> >(container, constant(ndx));
3357 }
3358
3359 template <typename T> string vecNamePrefix (void);
vecNamePrefix(void)3360 template <> string vecNamePrefix<float> (void) { return ""; }
vecNamePrefix(void)3361 template <> string vecNamePrefix<deFloat16>(void) { return ""; }
vecNamePrefix(void)3362 template <> string vecNamePrefix<double> (void) { return "d"; }
vecNamePrefix(void)3363 template <> string vecNamePrefix<int> (void) { return "i"; }
vecNamePrefix(void)3364 template <> string vecNamePrefix<bool> (void) { return "b"; }
3365
3366 template <typename T, int Size>
vecName(void)3367 string vecName (void) { return vecNamePrefix<T>() + "vec" + de::toString(Size); }
3368
3369 template <typename T, int Size> class GenVec;
3370
3371 template <typename T>
3372 class GenVec<T, 1> : public DerivedFunc<Signature<T, T> >
3373 {
3374 public:
3375 typedef typename GenVec<T, 1>::ArgExprs ArgExprs;
3376
getName(void) const3377 string getName (void) const
3378 {
3379 return "_" + vecName<T, 1>();
3380 }
3381
3382 protected:
3383
doExpand(ExpandContext &,const ArgExprs & args) const3384 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const { return args.a; }
3385 };
3386
3387 template <typename T>
3388 class GenVec<T, 2> : public PrimitiveFunc<Signature<Vector<T, 2>, T, T> >
3389 {
3390 public:
3391 typedef typename GenVec::IRet IRet;
3392 typedef typename GenVec::IArgs IArgs;
3393
getName(void) const3394 string getName (void) const
3395 {
3396 return vecName<T, 2>();
3397 }
3398
3399 protected:
doApply(const EvalContext &,const IArgs & iargs) const3400 IRet doApply (const EvalContext&, const IArgs& iargs) const
3401 {
3402 return IRet(iargs.a, iargs.b);
3403 }
3404 };
3405
3406 template <typename T>
3407 class GenVec<T, 3> : public PrimitiveFunc<Signature<Vector<T, 3>, T, T, T> >
3408 {
3409 public:
3410 typedef typename GenVec::IRet IRet;
3411 typedef typename GenVec::IArgs IArgs;
3412
getName(void) const3413 string getName (void) const
3414 {
3415 return vecName<T, 3>();
3416 }
3417
3418 protected:
doApply(const EvalContext &,const IArgs & iargs) const3419 IRet doApply (const EvalContext&, const IArgs& iargs) const
3420 {
3421 return IRet(iargs.a, iargs.b, iargs.c);
3422 }
3423 };
3424
3425 template <typename T>
3426 class GenVec<T, 4> : public PrimitiveFunc<Signature<Vector<T, 4>, T, T, T, T> >
3427 {
3428 public:
3429 typedef typename GenVec::IRet IRet;
3430 typedef typename GenVec::IArgs IArgs;
3431
getName(void) const3432 string getName (void) const { return vecName<T, 4>(); }
3433
3434 protected:
doApply(const EvalContext &,const IArgs & iargs) const3435 IRet doApply (const EvalContext&, const IArgs& iargs) const
3436 {
3437 return IRet(iargs.a, iargs.b, iargs.c, iargs.d);
3438 }
3439 };
3440
3441 template <typename T, int Rows, int Columns>
3442 class GenMat;
3443
3444 template <typename T, int Rows>
3445 class GenMat<T, Rows, 2> : public PrimitiveFunc<
3446 Signature<Matrix<T, Rows, 2>, Vector<T, Rows>, Vector<T, Rows> > >
3447 {
3448 public:
3449 typedef typename GenMat::Ret Ret;
3450 typedef typename GenMat::IRet IRet;
3451 typedef typename GenMat::IArgs IArgs;
3452
getName(void) const3453 string getName (void) const
3454 {
3455 return dataTypeNameOf<Ret>();
3456 }
3457
3458 protected:
3459
doApply(const EvalContext &,const IArgs & iargs) const3460 IRet doApply (const EvalContext&, const IArgs& iargs) const
3461 {
3462 IRet ret;
3463 ret[0] = iargs.a;
3464 ret[1] = iargs.b;
3465 return ret;
3466 }
3467 };
3468
3469 template <typename T, int Rows>
3470 class GenMat<T, Rows, 3> : public PrimitiveFunc<
3471 Signature<Matrix<T, Rows, 3>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows> > >
3472 {
3473 public:
3474 typedef typename GenMat::Ret Ret;
3475 typedef typename GenMat::IRet IRet;
3476 typedef typename GenMat::IArgs IArgs;
3477
getName(void) const3478 string getName (void) const
3479 {
3480 return dataTypeNameOf<Ret>();
3481 }
3482
3483 protected:
3484
doApply(const EvalContext &,const IArgs & iargs) const3485 IRet doApply (const EvalContext&, const IArgs& iargs) const
3486 {
3487 IRet ret;
3488 ret[0] = iargs.a;
3489 ret[1] = iargs.b;
3490 ret[2] = iargs.c;
3491 return ret;
3492 }
3493 };
3494
3495 template <typename T, int Rows>
3496 class GenMat<T, Rows, 4> : public PrimitiveFunc<
3497 Signature<Matrix<T, Rows, 4>,
3498 Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows> > >
3499 {
3500 public:
3501 typedef typename GenMat::Ret Ret;
3502 typedef typename GenMat::IRet IRet;
3503 typedef typename GenMat::IArgs IArgs;
3504
getName(void) const3505 string getName (void) const
3506 {
3507 return dataTypeNameOf<Ret>();
3508 }
3509
3510 protected:
doApply(const EvalContext &,const IArgs & iargs) const3511 IRet doApply (const EvalContext&, const IArgs& iargs) const
3512 {
3513 IRet ret;
3514 ret[0] = iargs.a;
3515 ret[1] = iargs.b;
3516 ret[2] = iargs.c;
3517 ret[3] = iargs.d;
3518 return ret;
3519 }
3520 };
3521
3522 template <typename T, int Rows>
mat2(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1)3523 ExprP<Matrix<T, Rows, 2> > mat2 (const ExprP<Vector<T, Rows> >& arg0,
3524 const ExprP<Vector<T, Rows> >& arg1)
3525 {
3526 return app<GenMat<T, Rows, 2> >(arg0, arg1);
3527 }
3528
3529 template <typename T, int Rows>
mat3(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2)3530 ExprP<Matrix<T, Rows, 3> > mat3 (const ExprP<Vector<T, Rows> >& arg0,
3531 const ExprP<Vector<T, Rows> >& arg1,
3532 const ExprP<Vector<T, Rows> >& arg2)
3533 {
3534 return app<GenMat<T, Rows, 3> >(arg0, arg1, arg2);
3535 }
3536
3537 template <typename T, int Rows>
mat4(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2,const ExprP<Vector<T,Rows>> & arg3)3538 ExprP<Matrix<T, Rows, 4> > mat4 (const ExprP<Vector<T, Rows> >& arg0,
3539 const ExprP<Vector<T, Rows> >& arg1,
3540 const ExprP<Vector<T, Rows> >& arg2,
3541 const ExprP<Vector<T, Rows> >& arg3)
3542 {
3543 return app<GenMat<T, Rows, 4> >(arg0, arg1, arg2, arg3);
3544 }
3545
3546 template <typename T, int Rows, int Cols>
3547 class MatNeg : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>,
3548 Matrix<T, Rows, Cols> > >
3549 {
3550 public:
3551 typedef typename MatNeg::IRet IRet;
3552 typedef typename MatNeg::IArgs IArgs;
3553
getName(void) const3554 string getName (void) const
3555 {
3556 return "_matNeg";
3557 }
3558
3559 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3560 void doPrint (ostream& os, const BaseArgExprs& args) const
3561 {
3562 os << "-(" << *args[0] << ")";
3563 }
3564
doApply(const EvalContext &,const IArgs & iargs) const3565 IRet doApply (const EvalContext&, const IArgs& iargs) const
3566 {
3567 IRet ret;
3568
3569 for (int col = 0; col < Cols; ++col)
3570 {
3571 for (int row = 0; row < Rows; ++row)
3572 ret[col][row] = -iargs.a[col][row];
3573 }
3574
3575 return ret;
3576 }
3577 };
3578
3579 template <typename T, typename Sig>
3580 class CompWiseFunc : public PrimitiveFunc<Sig>
3581 {
3582 public:
3583 typedef Func<Signature<T, T, T> > ScalarFunc;
3584
getName(void) const3585 string getName (void) const
3586 {
3587 return doGetScalarFunc().getName();
3588 }
3589 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3590 void doPrint (ostream& os,
3591 const BaseArgExprs& args) const
3592 {
3593 doGetScalarFunc().print(os, args);
3594 }
3595
3596 virtual
3597 const ScalarFunc& doGetScalarFunc (void) const = 0;
3598 };
3599
3600 template <typename T, int Rows, int Cols>
3601 class CompMatFuncBase : public CompWiseFunc<T, Signature<Matrix<T, Rows, Cols>,
3602 Matrix<T, Rows, Cols>,
3603 Matrix<T, Rows, Cols> > >
3604 {
3605 public:
3606 typedef typename CompMatFuncBase::IRet IRet;
3607 typedef typename CompMatFuncBase::IArgs IArgs;
3608
3609 protected:
3610
doApply(const EvalContext & ctx,const IArgs & iargs) const3611 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
3612 {
3613 IRet ret;
3614
3615 for (int col = 0; col < Cols; ++col)
3616 {
3617 for (int row = 0; row < Rows; ++row)
3618 ret[col][row] = this->doGetScalarFunc().apply(ctx,
3619 iargs.a[col][row],
3620 iargs.b[col][row]);
3621 }
3622
3623 return ret;
3624 }
3625 };
3626
3627 template <typename F, typename T, int Rows, int Cols>
3628 class CompMatFunc : public CompMatFuncBase<T, Rows, Cols>
3629 {
3630 protected:
doGetScalarFunc(void) const3631 const typename CompMatFunc::ScalarFunc& doGetScalarFunc (void) const
3632 {
3633 return instance<F>();
3634 }
3635 };
3636
3637 template <class T>
3638 class ScalarMatrixCompMult : public Mul< Signature<T, T, T> >
3639 {
3640 public:
3641
getName(void) const3642 string getName (void) const
3643 {
3644 return "matrixCompMult";
3645 }
3646
doPrint(ostream & os,const BaseArgExprs & args) const3647 void doPrint (ostream& os, const BaseArgExprs& args) const
3648 {
3649 Func<Signature<T, T, T> >::doPrint(os, args);
3650 }
3651 };
3652
3653 template <int Rows, int Cols, class T>
3654 class MatrixCompMult : public CompMatFunc<ScalarMatrixCompMult<T>, T, Rows, Cols>
3655 {
3656 };
3657
3658 template <int Rows, int Cols>
3659 class ScalarMatFuncBase : public CompWiseFunc<float, Signature<Matrix<float, Rows, Cols>,
3660 Matrix<float, Rows, Cols>,
3661 float> >
3662 {
3663 public:
3664 typedef typename ScalarMatFuncBase::IRet IRet;
3665 typedef typename ScalarMatFuncBase::IArgs IArgs;
3666
3667 protected:
3668
doApply(const EvalContext & ctx,const IArgs & iargs) const3669 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
3670 {
3671 IRet ret;
3672
3673 for (int col = 0; col < Cols; ++col)
3674 {
3675 for (int row = 0; row < Rows; ++row)
3676 ret[col][row] = this->doGetScalarFunc().apply(ctx, iargs.a[col][row], iargs.b);
3677 }
3678
3679 return ret;
3680 }
3681 };
3682
3683 template <typename F, int Rows, int Cols>
3684 class ScalarMatFunc : public ScalarMatFuncBase<Rows, Cols>
3685 {
3686 protected:
doGetScalarFunc(void) const3687 const typename ScalarMatFunc::ScalarFunc& doGetScalarFunc (void) const
3688 {
3689 return instance<F>();
3690 }
3691 };
3692
3693 template<typename T, int Size> struct GenXType;
3694
3695 template<typename T>
3696 struct GenXType<T, 1>
3697 {
genXTypevkt::shaderexecutor::Functions::GenXType3698 static ExprP<T> genXType (const ExprP<T>& x) { return x; }
3699 };
3700
3701 template<typename T>
3702 struct GenXType<T, 2>
3703 {
genXTypevkt::shaderexecutor::Functions::GenXType3704 static ExprP<Vector<T, 2> > genXType (const ExprP<T>& x)
3705 {
3706 return app<GenVec<T, 2> >(x, x);
3707 }
3708 };
3709
3710 template<typename T>
3711 struct GenXType<T, 3>
3712 {
genXTypevkt::shaderexecutor::Functions::GenXType3713 static ExprP<Vector<T, 3> > genXType (const ExprP<T>& x)
3714 {
3715 return app<GenVec<T, 3> >(x, x, x);
3716 }
3717 };
3718
3719 template<typename T>
3720 struct GenXType<T, 4>
3721 {
genXTypevkt::shaderexecutor::Functions::GenXType3722 static ExprP<Vector<T, 4> > genXType (const ExprP<T>& x)
3723 {
3724 return app<GenVec<T, 4> >(x, x, x, x);
3725 }
3726 };
3727
3728 //! Returns an expression of vector of size `Size` (or scalar if Size == 1),
3729 //! with each element initialized with the expression `x`.
3730 template<typename T, int Size>
genXType(const ExprP<T> & x)3731 ExprP<typename ContainerOf<T, Size>::Container> genXType (const ExprP<T>& x)
3732 {
3733 return GenXType<T, Size>::genXType(x);
3734 }
3735
3736 typedef GenVec<float, 2> FloatVec2;
3737 DEFINE_CONSTRUCTOR2(FloatVec2, Vec2, vec2, float, float)
3738
3739 typedef GenVec<deFloat16, 2> FloatVec2_16bit;
3740 DEFINE_CONSTRUCTOR2(FloatVec2_16bit, Vec2_16Bit, vec2, deFloat16, deFloat16)
3741
3742 typedef GenVec<double, 2> DoubleVec2;
3743 DEFINE_CONSTRUCTOR2(DoubleVec2, Vec2_64Bit, vec2, double, double)
3744
3745 typedef GenVec<float, 3> FloatVec3;
3746 DEFINE_CONSTRUCTOR3(FloatVec3, Vec3, vec3, float, float, float)
3747
3748 typedef GenVec<deFloat16, 3> FloatVec3_16bit;
3749 DEFINE_CONSTRUCTOR3(FloatVec3_16bit, Vec3_16Bit, vec3, deFloat16, deFloat16, deFloat16)
3750
3751 typedef GenVec<double, 3> DoubleVec3;
3752 DEFINE_CONSTRUCTOR3(DoubleVec3, Vec3_64Bit, vec3, double, double, double)
3753
3754 typedef GenVec<float, 4> FloatVec4;
3755 DEFINE_CONSTRUCTOR4(FloatVec4, Vec4, vec4, float, float, float, float)
3756
3757 typedef GenVec<deFloat16, 4> FloatVec4_16bit;
3758 DEFINE_CONSTRUCTOR4(FloatVec4_16bit, Vec4_16Bit, vec4, deFloat16, deFloat16, deFloat16, deFloat16)
3759
3760 typedef GenVec<double, 4> DoubleVec4;
3761 DEFINE_CONSTRUCTOR4(DoubleVec4, Vec4_64Bit, vec4, double, double, double, double)
3762
3763 template <class T>
3764 const ExprP<T> getConstZero(void);
3765 template <class T>
3766 const ExprP<T> getConstOne(void);
3767 template <class T>
3768 const ExprP<T> getConstTwo(void);
3769
3770 template <>
getConstZero(void)3771 const ExprP<float> getConstZero<float>(void)
3772 {
3773 return constant(0.0f);
3774 }
3775
3776 template <>
getConstZero(void)3777 const ExprP<deFloat16> getConstZero<deFloat16>(void)
3778 {
3779 return constant((deFloat16)FLOAT16_0_0);
3780 }
3781
3782 template <>
getConstZero(void)3783 const ExprP<double> getConstZero<double>(void)
3784 {
3785 return constant(0.0);
3786 }
3787
3788 template <>
getConstOne(void)3789 const ExprP<float> getConstOne<float>(void)
3790 {
3791 return constant(1.0f);
3792 }
3793
3794 template <>
getConstOne(void)3795 const ExprP<deFloat16> getConstOne<deFloat16>(void)
3796 {
3797 return constant((deFloat16)FLOAT16_1_0);
3798 }
3799
3800 template <>
getConstOne(void)3801 const ExprP<double> getConstOne<double>(void)
3802 {
3803 return constant(1.0);
3804 }
3805
3806 template <>
getConstTwo(void)3807 const ExprP<float> getConstTwo<float>(void)
3808 {
3809 return constant(2.0f);
3810 }
3811
3812 template <>
getConstTwo(void)3813 const ExprP<deFloat16> getConstTwo<deFloat16>(void)
3814 {
3815 return constant((deFloat16)FLOAT16_2_0);
3816 }
3817
3818 template <>
getConstTwo(void)3819 const ExprP<double> getConstTwo<double>(void)
3820 {
3821 return constant(2.0);
3822 }
3823
3824 template <int Size, class T>
3825 class Dot : public DerivedFunc<Signature<T, Vector<T, Size>, Vector<T, Size> > >
3826 {
3827 public:
3828 typedef typename Dot::ArgExprs ArgExprs;
3829
getName(void) const3830 string getName (void) const
3831 {
3832 return "dot";
3833 }
3834
3835 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3836 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const
3837 {
3838 ExprP<T> op[Size];
3839 // Precompute all products.
3840 for (int ndx = 0; ndx < Size; ++ndx)
3841 op[ndx] = args.a[ndx] * args.b[ndx];
3842
3843 int idx[Size];
3844 //Prepare an array of indices.
3845 for (int ndx = 0; ndx < Size; ++ndx)
3846 idx[ndx] = ndx;
3847
3848 ExprP<T> res = op[0];
3849 // Compute the first dot alternative: SUM(a[i]*b[i]), i = 0 .. Size-1
3850 for (int ndx = 1; ndx < Size; ++ndx)
3851 res = res + op[ndx];
3852
3853 // Generate all permutations of indices and
3854 // using a permutation compute a dot alternative.
3855 // Generates all possible variants fo summation of products in the dot product expansion expression.
3856 do {
3857 ExprP<T> alt = getConstZero<T>();
3858 for (int ndx = 0; ndx < Size; ++ndx)
3859 alt = alt + op[idx[ndx]];
3860 res = alternatives(res, alt);
3861 } while (std::next_permutation(idx, idx + Size));
3862
3863 return res;
3864 }
3865 };
3866
3867 template <class T>
3868 class Dot<1, T> : public DerivedFunc<Signature<T, T, T> >
3869 {
3870 public:
3871 typedef typename DerivedFunc<Signature<T, T, T> >::ArgExprs TArgExprs;
3872
getName(void) const3873 string getName (void) const
3874 {
3875 return "dot";
3876 }
3877
doExpand(ExpandContext &,const TArgExprs & args) const3878 ExprP<T> doExpand (ExpandContext&, const TArgExprs& args) const
3879 {
3880 return args.a * args.b;
3881 }
3882 };
3883
3884 template <int Size>
dot(const ExprP<Vector<deFloat16,Size>> & x,const ExprP<Vector<deFloat16,Size>> & y)3885 ExprP<deFloat16> dot (const ExprP<Vector<deFloat16, Size> >& x, const ExprP<Vector<deFloat16, Size> >& y)
3886 {
3887 return app<Dot<Size, deFloat16> >(x, y);
3888 }
3889
dot(const ExprP<deFloat16> & x,const ExprP<deFloat16> & y)3890 ExprP<deFloat16> dot (const ExprP<deFloat16>& x, const ExprP<deFloat16>& y)
3891 {
3892 return app<Dot<1, deFloat16> >(x, y);
3893 }
3894
3895 template <int Size>
dot(const ExprP<Vector<float,Size>> & x,const ExprP<Vector<float,Size>> & y)3896 ExprP<float> dot (const ExprP<Vector<float, Size> >& x, const ExprP<Vector<float, Size> >& y)
3897 {
3898 return app<Dot<Size, float> >(x, y);
3899 }
3900
dot(const ExprP<float> & x,const ExprP<float> & y)3901 ExprP<float> dot (const ExprP<float>& x, const ExprP<float>& y)
3902 {
3903 return app<Dot<1, float> >(x, y);
3904 }
3905
3906 template <int Size>
dot(const ExprP<Vector<double,Size>> & x,const ExprP<Vector<double,Size>> & y)3907 ExprP<double> dot (const ExprP<Vector<double, Size> >& x, const ExprP<Vector<double, Size> >& y)
3908 {
3909 return app<Dot<Size, double> >(x, y);
3910 }
3911
dot(const ExprP<double> & x,const ExprP<double> & y)3912 ExprP<double> dot (const ExprP<double>& x, const ExprP<double>& y)
3913 {
3914 return app<Dot<1, double> >(x, y);
3915 }
3916
3917 template <int Size, class T>
3918 class Length : public DerivedFunc<
3919 Signature<T, typename ContainerOf<T, Size>::Container> >
3920 {
3921 public:
3922 typedef typename Length::ArgExprs ArgExprs;
3923
getName(void) const3924 string getName (void) const
3925 {
3926 return "length";
3927 }
3928
3929 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3930 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const
3931 {
3932 return sqrt(dot(args.a, args.a));
3933 }
3934 };
3935
3936
3937 template <class T, class TRet>
length(const ExprP<T> & x)3938 ExprP<TRet> length (const ExprP<T>& x)
3939 {
3940 return app<Length<1, T> >(x);
3941 }
3942
3943 template <int Size, class T, class TRet>
length(const ExprP<typename ContainerOf<T,Size>::Container> & x)3944 ExprP<TRet> length (const ExprP<typename ContainerOf<T, Size>::Container>& x)
3945 {
3946 return app<Length<Size, T> >(x);
3947 }
3948
3949 template <int Size, class T>
3950 class Distance : public DerivedFunc<
3951 Signature<T,
3952 typename ContainerOf<T, Size>::Container,
3953 typename ContainerOf<T, Size>::Container> >
3954 {
3955 public:
3956 typedef typename Distance::Ret Ret;
3957 typedef typename Distance::ArgExprs ArgExprs;
3958
getName(void) const3959 string getName (void) const
3960 {
3961 return "distance";
3962 }
3963
3964 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3965 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
3966 {
3967 return length<Size, T, Ret>(args.a - args.b);
3968 }
3969 };
3970
3971 // cross
3972
3973 class Cross : public DerivedFunc<Signature<Vec3, Vec3, Vec3> >
3974 {
3975 public:
getName(void) const3976 string getName (void) const
3977 {
3978 return "cross";
3979 }
3980
3981 protected:
doExpand(ExpandContext &,const ArgExprs & x) const3982 ExprP<Vec3> doExpand (ExpandContext&, const ArgExprs& x) const
3983 {
3984 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
3985 x.a[2] * x.b[0] - x.b[2] * x.a[0],
3986 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
3987 }
3988 };
3989
3990 class Cross16Bit : public DerivedFunc<Signature<Vec3_16Bit, Vec3_16Bit, Vec3_16Bit> >
3991 {
3992 public:
getName(void) const3993 string getName (void) const
3994 {
3995 return "cross";
3996 }
3997
3998 protected:
doExpand(ExpandContext &,const ArgExprs & x) const3999 ExprP<Vec3_16Bit> doExpand (ExpandContext&, const ArgExprs& x) const
4000 {
4001 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
4002 x.a[2] * x.b[0] - x.b[2] * x.a[0],
4003 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4004 }
4005 };
4006
4007 class Cross64Bit : public DerivedFunc<Signature<Vec3_64Bit, Vec3_64Bit, Vec3_64Bit> >
4008 {
4009 public:
getName(void) const4010 string getName (void) const
4011 {
4012 return "cross";
4013 }
4014
4015 protected:
doExpand(ExpandContext &,const ArgExprs & x) const4016 ExprP<Vec3_64Bit> doExpand (ExpandContext&, const ArgExprs& x) const
4017 {
4018 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
4019 x.a[2] * x.b[0] - x.b[2] * x.a[0],
4020 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4021 }
4022 };
4023
4024 DEFINE_CONSTRUCTOR2(Cross, Vec3, cross, Vec3, Vec3)
4025 DEFINE_CONSTRUCTOR2(Cross16Bit, Vec3_16Bit, cross, Vec3_16Bit, Vec3_16Bit)
4026 DEFINE_CONSTRUCTOR2(Cross64Bit, Vec3_64Bit, cross, Vec3_64Bit, Vec3_64Bit)
4027
4028 template<int Size, class T>
4029 class Normalize : public DerivedFunc<
4030 Signature<typename ContainerOf<T, Size>::Container,
4031 typename ContainerOf<T, Size>::Container> >
4032 {
4033 public:
4034 typedef typename Normalize::Ret Ret;
4035 typedef typename Normalize::ArgExprs ArgExprs;
4036
getName(void) const4037 string getName (void) const
4038 {
4039 return "normalize";
4040 }
4041
4042 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4043 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4044 {
4045 return args.a / length<Size, T, T>(args.a);
4046 }
4047 };
4048
4049 template <int Size, class T>
4050 class FaceForward : public DerivedFunc<
4051 Signature<typename ContainerOf<T, Size>::Container,
4052 typename ContainerOf<T, Size>::Container,
4053 typename ContainerOf<T, Size>::Container,
4054 typename ContainerOf<T, Size>::Container> >
4055 {
4056 public:
4057 typedef typename FaceForward::Ret Ret;
4058 typedef typename FaceForward::ArgExprs ArgExprs;
4059
getName(void) const4060 string getName (void) const
4061 {
4062 return "faceforward";
4063 }
4064
4065 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4066 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4067 {
4068 return cond(dot(args.c, args.b) < getConstZero<T>(), args.a, -args.a);
4069 }
4070 };
4071
4072 template <int Size, class T>
4073 class Reflect : public DerivedFunc<
4074 Signature<typename ContainerOf<T, Size>::Container,
4075 typename ContainerOf<T, Size>::Container,
4076 typename ContainerOf<T, Size>::Container> >
4077 {
4078 public:
4079 typedef typename Reflect::Ret Ret;
4080 typedef typename Reflect::Arg0 Arg0;
4081 typedef typename Reflect::Arg1 Arg1;
4082 typedef typename Reflect::ArgExprs ArgExprs;
4083
getName(void) const4084 string getName (void) const
4085 {
4086 return "reflect";
4087 }
4088
4089 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4090 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4091 {
4092 const ExprP<Arg0>& i = args.a;
4093 const ExprP<Arg1>& n = args.b;
4094 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4095
4096 return i - alternatives((n * dotNI) * getConstTwo<T>(),
4097 alternatives( n * (dotNI * getConstTwo<T>()),
4098 alternatives(n * dot(i * getConstTwo<T>(), n),
4099 n * dot(i, n * getConstTwo<T>())
4100 )
4101 )
4102 );
4103 }
4104 };
4105
4106 template <class T>
4107 class Reflect<1, T> : public DerivedFunc<
4108 Signature<T, T, T> >
4109 {
4110 public:
4111 typedef typename Reflect::Ret Ret;
4112 typedef typename Reflect::Arg0 Arg0;
4113 typedef typename Reflect::Arg1 Arg1;
4114 typedef typename Reflect::ArgExprs ArgExprs;
4115
getName(void) const4116 string getName (void) const
4117 {
4118 return "reflect";
4119 }
4120
4121 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4122 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4123 {
4124 const ExprP<Arg0>& i = args.a;
4125 const ExprP<Arg1>& n = args.b;
4126 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4127
4128 return i - alternatives((n * dotNI) * getConstTwo<T>(),
4129 alternatives( n * (dotNI * getConstTwo<T>()),
4130 alternatives(n * dot(i * getConstTwo<T>(), n),
4131 alternatives(n * dot(i, n * getConstTwo<T>()),
4132 dot(n * n, i * getConstTwo<T>()))
4133 )
4134 )
4135 );
4136 }
4137 };
4138
4139 template <int Size, class T>
4140 class Refract : public DerivedFunc<
4141 Signature<typename ContainerOf<T, Size>::Container,
4142 typename ContainerOf<T, Size>::Container,
4143 typename ContainerOf<T, Size>::Container,
4144 T> >
4145 {
4146 public:
4147 typedef typename Refract::Ret Ret;
4148 typedef typename Refract::Arg0 Arg0;
4149 typedef typename Refract::Arg1 Arg1;
4150 typedef typename Refract::Arg2 Arg2;
4151 typedef typename Refract::ArgExprs ArgExprs;
4152
getName(void) const4153 string getName (void) const
4154 {
4155 return "refract";
4156 }
4157
4158 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4159 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4160 {
4161 const ExprP<Arg0>& i = args.a;
4162 const ExprP<Arg1>& n = args.b;
4163 const ExprP<Arg2>& eta = args.c;
4164 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4165 const ExprP<T> k = bindExpression("k", ctx, getConstOne<T>() - eta * eta *
4166 (getConstOne<T>() - dotNI * dotNI));
4167 return cond(k < getConstZero<T>(),
4168 genXType<T, Size>(getConstZero<T>()),
4169 i * eta - n * (eta * dotNI + sqrt(k)));
4170 }
4171 };
4172
4173 template <class T>
4174 class PreciseFunc1 : public CFloatFunc1<T>
4175 {
4176 public:
PreciseFunc1(const string & name,DoubleFunc1 & func)4177 PreciseFunc1 (const string& name, DoubleFunc1& func) : CFloatFunc1<T> (name, func) {}
4178 protected:
precision(const EvalContext &,double,double) const4179 double precision (const EvalContext&, double, double) const { return 0.0; }
4180 };
4181
4182 template <class T>
4183 class Abs : public PreciseFunc1<T>
4184 {
4185 public:
Abs(void)4186 Abs (void) : PreciseFunc1<T> ("abs", deAbs) {}
4187 };
4188
4189 template <class T>
4190 class Sign : public PreciseFunc1<T>
4191 {
4192 public:
Sign(void)4193 Sign (void) : PreciseFunc1<T> ("sign", deSign) {}
4194 };
4195
4196 template <class T>
4197 class Floor : public PreciseFunc1<T>
4198 {
4199 public:
Floor(void)4200 Floor (void) : PreciseFunc1<T> ("floor", deFloor) {}
4201 };
4202
4203 template <class T>
4204 class Trunc : public PreciseFunc1<T>
4205 {
4206 public:
Trunc(void)4207 Trunc (void) : PreciseFunc1<T> ("trunc", deTrunc) {}
4208 };
4209
4210 template <class T>
4211 class Round : public FloatFunc1<T>
4212 {
4213 public:
getName(void) const4214 string getName (void) const { return "round"; }
4215
4216 protected:
applyPoint(const EvalContext &,double x) const4217 Interval applyPoint (const EvalContext&, double x) const
4218 {
4219 double truncated = 0.0;
4220 const double fract = deModf(x, &truncated);
4221 Interval ret;
4222
4223 if (fabs(fract) <= 0.5)
4224 ret |= truncated;
4225 if (fabs(fract) >= 0.5)
4226 ret |= truncated + deSign(fract);
4227
4228 return ret;
4229 }
4230
precision(const EvalContext &,double,double) const4231 double precision (const EvalContext&, double, double) const { return 0.0; }
4232 };
4233
4234 template <class T>
4235 class RoundEven : public PreciseFunc1<T>
4236 {
4237 public:
RoundEven(void)4238 RoundEven (void) : PreciseFunc1<T> ("roundEven", deRoundEven) {}
4239 };
4240
4241 template <class T>
4242 class Ceil : public PreciseFunc1<T>
4243 {
4244 public:
Ceil(void)4245 Ceil (void) : PreciseFunc1<T> ("ceil", deCeil) {}
4246 };
4247
4248 typedef Floor< Signature<float, float> > Floor32Bit;
4249 typedef Floor< Signature<deFloat16, deFloat16> > Floor16Bit;
4250 typedef Floor< Signature<double, double> > Floor64Bit;
4251
4252 typedef Trunc< Signature<float, float> > Trunc32Bit;
4253 typedef Trunc< Signature<deFloat16, deFloat16> > Trunc16Bit;
4254 typedef Trunc< Signature<double, double> > Trunc64Bit;
4255
4256 typedef Trunc< Signature<float, float> > Trunc32Bit;
4257 typedef Trunc< Signature<deFloat16, deFloat16> > Trunc16Bit;
4258
4259 DEFINE_DERIVED_FLOAT1(Fract, fract, x, x - app<Floor32Bit>(x))
4260 DEFINE_DERIVED_FLOAT1_16BIT(Fract16Bit, fract, x, x - app<Floor16Bit>(x))
4261 DEFINE_DERIVED_DOUBLE1(Fract64Bit, fract, x, x - app<Floor64Bit>(x))
4262
4263 template <class T>
4264 class PreciseFunc2 : public CFloatFunc2<T>
4265 {
4266 public:
PreciseFunc2(const string & name,DoubleFunc2 & func)4267 PreciseFunc2 (const string& name, DoubleFunc2& func) : CFloatFunc2<T> (name, func) {}
4268 protected:
precision(const EvalContext &,double,double,double) const4269 double precision (const EvalContext&, double, double, double) const { return 0.0; }
4270 };
4271
4272 DEFINE_DERIVED_FLOAT2(Mod32Bit, mod, x, y, x - y * app<Floor32Bit>(x / y))
4273 DEFINE_DERIVED_FLOAT2_16BIT(Mod16Bit, mod, x, y, x - y * app<Floor16Bit>(x / y))
4274 DEFINE_DERIVED_DOUBLE2(Mod64Bit, mod, x, y, x - y * app<Floor64Bit>(x / y))
4275
4276 DEFINE_CASED_DERIVED_FLOAT2(FRem32Bit, frem, x, y, x - y * app<Trunc32Bit>(x / y), SPIRV_CASETYPE_FREM)
4277 DEFINE_CASED_DERIVED_FLOAT2_16BIT(FRem16Bit, frem, x, y, x - y * app<Trunc16Bit>(x / y), SPIRV_CASETYPE_FREM)
4278 DEFINE_CASED_DERIVED_DOUBLE2(FRem64Bit, frem, x, y, x - y * app<Trunc64Bit>(x / y), SPIRV_CASETYPE_FREM)
4279
4280 template <class T>
4281 class Modf : public PrimitiveFunc<T>
4282 {
4283 public:
4284 typedef typename Modf<T>::IArgs TIArgs;
4285 typedef typename Modf<T>::IRet TIRet;
getName(void) const4286 string getName (void) const
4287 {
4288 return "modf";
4289 }
4290
4291 protected:
doApply(const EvalContext & ctx,const TIArgs & iargs) const4292 TIRet doApply (const EvalContext& ctx, const TIArgs& iargs) const
4293 {
4294 Interval fracIV;
4295 Interval& wholeIV = const_cast<Interval&>(iargs.b);
4296 double intPart = 0;
4297
4298 TCU_INTERVAL_APPLY_MONOTONE1(fracIV, x, iargs.a, frac, frac = deModf(x, &intPart));
4299 TCU_INTERVAL_APPLY_MONOTONE1(wholeIV, x, iargs.a, whole,
4300 deModf(x, &intPart); whole = intPart);
4301
4302 if (!iargs.a.isFinite(ctx.format.getMaxValue()))
4303 {
4304 // Behavior on modf(Inf) not well-defined, allow anything as a fractional part
4305 // See Khronos bug 13907
4306 fracIV |= TCU_NAN;
4307 }
4308
4309 return fracIV;
4310 }
4311
getOutParamIndex(void) const4312 int getOutParamIndex (void) const
4313 {
4314 return 1;
4315 }
4316 };
4317 typedef Modf< Signature<float, float, float> > Modf32Bit;
4318 typedef Modf< Signature<deFloat16, deFloat16, deFloat16> > Modf16Bit;
4319 typedef Modf< Signature<double, double, double> > Modf64Bit;
4320
4321 template <class T>
4322 class ModfStruct : public Modf<T>
4323 {
4324 public:
getName(void) const4325 virtual string getName (void) const { return "modfstruct"; }
getSpirvCase(void) const4326 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_MODFSTRUCT; }
4327 };
4328 typedef ModfStruct< Signature<float, float, float> > ModfStruct32Bit;
4329 typedef ModfStruct< Signature<deFloat16, deFloat16, deFloat16> > ModfStruct16Bit;
4330 typedef ModfStruct< Signature<double, double, double> > ModfStruct64Bit;
4331
4332 template <class T>
Min(void)4333 class Min : public PreciseFunc2<T> { public: Min (void) : PreciseFunc2<T> ("min", deMin) {} };
4334 template <class T>
Max(void)4335 class Max : public PreciseFunc2<T> { public: Max (void) : PreciseFunc2<T> ("max", deMax) {} };
4336
4337 template <class T>
4338 class Clamp : public FloatFunc3<T>
4339 {
4340 public:
getName(void) const4341 string getName (void) const { return "clamp"; }
4342
applyExact(double x,double minVal,double maxVal) const4343 double applyExact (double x, double minVal, double maxVal) const
4344 {
4345 return de::min(de::max(x, minVal), maxVal);
4346 }
4347
precision(const EvalContext &,double,double,double minVal,double maxVal) const4348 double precision (const EvalContext&, double, double, double minVal, double maxVal) const
4349 {
4350 return minVal > maxVal ? TCU_NAN : 0.0;
4351 }
4352 };
4353
//! clamp() expression for deFloat16 operands.
ExprP<deFloat16> clamp (const ExprP<deFloat16>& x, const ExprP<deFloat16>& minVal, const ExprP<deFloat16>& maxVal)
{
	typedef Clamp< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > ClampFunc;

	return app<ClampFunc>(x, minVal, maxVal);
}
4358
//! clamp() expression for float operands.
ExprP<float> clamp (const ExprP<float>& x, const ExprP<float>& minVal, const ExprP<float>& maxVal)
{
	typedef Clamp< Signature<float, float, float, float> > ClampFunc;

	return app<ClampFunc>(x, minVal, maxVal);
}
4363
//! clamp() expression for double operands.
ExprP<double> clamp (const ExprP<double>& x, const ExprP<double>& minVal, const ExprP<double>& maxVal)
{
	typedef Clamp< Signature<double, double, double, double> > ClampFunc;

	return app<ClampFunc>(x, minVal, maxVal);
}
4368
4369 template <class T>
4370 class NanIfGreaterOrEqual : public FloatFunc2<T>
4371 {
4372 public:
getName(void) const4373 string getName (void) const { return "nanIfGreaterOrEqual"; }
4374
applyExact(double edge0,double edge1) const4375 double applyExact (double edge0, double edge1) const
4376 {
4377 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4378 }
4379
precision(const EvalContext &,double,double edge0,double edge1) const4380 double precision (const EvalContext&, double, double edge0, double edge1) const
4381 {
4382 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4383 }
4384 };
4385
//! nanIfGreaterOrEqual() expression for deFloat16 operands.
ExprP<deFloat16> nanIfGreaterOrEqual (const ExprP<deFloat16>& edge0, const ExprP<deFloat16>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<deFloat16, deFloat16, deFloat16> > NanGuardFunc;

	return app<NanGuardFunc>(edge0, edge1);
}
4390
//! nanIfGreaterOrEqual() expression for float operands.
ExprP<float> nanIfGreaterOrEqual (const ExprP<float>& edge0, const ExprP<float>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<float, float, float> > NanGuardFunc;

	return app<NanGuardFunc>(edge0, edge1);
}
4395
//! nanIfGreaterOrEqual() expression for double operands.
ExprP<double> nanIfGreaterOrEqual (const ExprP<double>& edge0, const ExprP<double>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<double, double, double> > NanGuardFunc;

	return app<NanGuardFunc>(edge0, edge1);
}
4400
4401 DEFINE_DERIVED_FLOAT3(Mix, mix, x, y, a, alternatives((x * (constant(1.0f) - a)) + y * a,
4402 x + (y - x) * a))
4403
4404 DEFINE_DERIVED_FLOAT3_16BIT(Mix16Bit, mix, x, y, a, alternatives((x * (constant((deFloat16)FLOAT16_1_0) - a)) + y * a,
4405 x + (y - x) * a))
4406
4407 DEFINE_DERIVED_DOUBLE3(Mix64Bit, mix, x, y, a, alternatives((x * (constant(1.0) - a)) + y * a,
4408 x + (y - x) * a))
4409
//! Reference step function: 0.0 when x is below edge, 1.0 in every other case.
static double step (double edge, double x)
{
	if (x < edge)
		return 0.0;

	return 1.0;
}
4414
4415 template <class T>
Step(void)4416 class Step : public PreciseFunc2<T> { public: Step (void) : PreciseFunc2<T> ("step", step) {} };
4417
4418 template <class T>
4419 class SmoothStep : public DerivedFunc<T>
4420 {
4421 public:
4422 typedef typename SmoothStep<T>::ArgExprs TArgExprs;
4423 typedef typename SmoothStep<T>::Ret TRet;
getName(void) const4424 string getName (void) const
4425 {
4426 return "smoothstep";
4427 }
4428
4429 protected:
4430
4431 ExprP<TRet> doExpand (ExpandContext& ctx, const TArgExprs& args) const;
4432 };
4433
4434 template<>
doExpand(ExpandContext & ctx,const SmoothStep<Signature<float,float,float,float>>::ArgExprs & args) const4435 ExprP<SmoothStep< Signature<float, float, float, float> >::Ret> SmoothStep< Signature<float, float, float, float> >::doExpand (ExpandContext& ctx, const SmoothStep< Signature<float, float, float, float> >::ArgExprs& args) const
4436 {
4437 const ExprP<float>& edge0 = args.a;
4438 const ExprP<float>& edge1 = args.b;
4439 const ExprP<float>& x = args.c;
4440 const ExprP<float> tExpr = clamp((x - edge0) / (edge1 - edge0), constant(0.0f), constant(1.0f))
4441 + nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4442 const ExprP<float> t = bindExpression("t", ctx, tExpr);
4443
4444 return (t * t * (constant(3.0f) - constant(2.0f) * t));
4445 }
4446
4447 template<>
doExpand(ExpandContext & ctx,const TArgExprs & args) const4448 ExprP<SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> >::TRet> SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> >::doExpand (ExpandContext& ctx, const TArgExprs& args) const
4449 {
4450 const ExprP<deFloat16>& edge0 = args.a;
4451 const ExprP<deFloat16>& edge1 = args.b;
4452 const ExprP<deFloat16>& x = args.c;
4453 const ExprP<deFloat16> tExpr = clamp(( x - edge0 ) / ( edge1 - edge0 ),
4454 constant((deFloat16)FLOAT16_0_0), constant((deFloat16)FLOAT16_1_0))
4455 + nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4456 const ExprP<deFloat16> t = bindExpression("t", ctx, tExpr);
4457
4458 return (t * t * (constant((deFloat16)FLOAT16_3_0) - constant((deFloat16)FLOAT16_2_0) * t));
4459 }
4460
4461 template<>
doExpand(ExpandContext & ctx,const SmoothStep<Signature<double,double,double,double>>::ArgExprs & args) const4462 ExprP<SmoothStep< Signature<double, double, double, double> >::Ret> SmoothStep< Signature<double, double, double, double> >::doExpand (ExpandContext& ctx, const SmoothStep< Signature<double, double, double, double> >::ArgExprs& args) const
4463 {
4464 const ExprP<double>& edge0 = args.a;
4465 const ExprP<double>& edge1 = args.b;
4466 const ExprP<double>& x = args.c;
4467 const ExprP<double> tExpr = clamp((x - edge0) / (edge1 - edge0), constant(0.0), constant(1.0))
4468 + nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4469 const ExprP<double> t = bindExpression("t", ctx, tExpr);
4470
4471 return (t * t * (constant(3.0) - constant(2.0) * t));
4472 }
4473
4474 //Signature<float, float, int>
4475 //Signature<float, deFloat16, int>
4476 //Signature<double, double, int>
4477 template <class T>
4478 class FrExp : public PrimitiveFunc<T>
4479 {
4480 public:
getName(void) const4481 string getName (void) const
4482 {
4483 return "frexp";
4484 }
4485
4486 typedef typename FrExp::IRet IRet;
4487 typedef typename FrExp::IArgs IArgs;
4488 typedef typename FrExp::IArg0 IArg0;
4489 typedef typename FrExp::IArg1 IArg1;
4490
4491 protected:
doApply(const EvalContext &,const IArgs & iargs) const4492 IRet doApply (const EvalContext&, const IArgs& iargs) const
4493 {
4494 IRet ret;
4495 const IArg0& x = iargs.a;
4496 IArg1& exponent = const_cast<IArg1&>(iargs.b);
4497
4498 if (x.hasNaN() || x.contains(TCU_INFINITY) || x.contains(-TCU_INFINITY))
4499 {
4500 // GLSL (in contrast to IEEE) says that result of applying frexp
4501 // to infinity is undefined
4502 ret = Interval::unbounded() | TCU_NAN;
4503 exponent = Interval(-deLdExp(1.0, 31), deLdExp(1.0, 31)-1);
4504 }
4505 else if (!x.empty())
4506 {
4507 int loExp = 0;
4508 const double loFrac = deFrExp(x.lo(), &loExp);
4509 int hiExp = 0;
4510 const double hiFrac = deFrExp(x.hi(), &hiExp);
4511
4512 if (deSign(loFrac) != deSign(hiFrac))
4513 {
4514 exponent = Interval(-TCU_INFINITY, de::max(loExp, hiExp));
4515 ret = Interval();
4516 if (deSign(loFrac) < 0)
4517 ret |= Interval(-1.0 + DBL_EPSILON*0.5, 0.0);
4518 if (deSign(hiFrac) > 0)
4519 ret |= Interval(0.0, 1.0 - DBL_EPSILON*0.5);
4520 }
4521 else
4522 {
4523 exponent = Interval(loExp, hiExp);
4524 if (loExp == hiExp)
4525 ret = Interval(loFrac, hiFrac);
4526 else
4527 ret = deSign(loFrac) * Interval(0.5, 1.0 - DBL_EPSILON*0.5);
4528 }
4529 }
4530
4531 return ret;
4532 }
4533
getOutParamIndex(void) const4534 int getOutParamIndex (void) const
4535 {
4536 return 1;
4537 }
4538 };
4539 typedef FrExp< Signature<float, float, int> > Frexp32Bit;
4540 typedef FrExp< Signature<deFloat16, deFloat16, int> > Frexp16Bit;
4541 typedef FrExp< Signature<double, double, int> > Frexp64Bit;
4542
4543 template <class T>
4544 class FrexpStruct : public FrExp<T>
4545 {
4546 public:
getName(void) const4547 virtual string getName (void) const { return "frexpstruct"; }
getSpirvCase(void) const4548 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_FREXPSTRUCT; }
4549 };
4550 typedef FrexpStruct< Signature<float, float, int> > FrexpStruct32Bit;
4551 typedef FrexpStruct< Signature<deFloat16, deFloat16, int> > FrexpStruct16Bit;
4552 typedef FrexpStruct< Signature<double, double, int> > FrexpStruct64Bit;
4553
4554 //Signature<float, float, int>
4555 //Signature<deFloat16, deFloat16, int>
4556 //Signature<double, double, int>
4557 template <class T>
4558 class LdExp : public PrimitiveFunc<T >
4559 {
4560 public:
4561 typedef typename LdExp::IRet IRet;
4562 typedef typename LdExp::IArgs IArgs;
4563
getName(void) const4564 string getName (void) const
4565 {
4566 return "ldexp";
4567 }
4568
4569 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4570 Interval doApply (const EvalContext& ctx, const IArgs& iargs) const
4571 {
4572 const int minExp = ctx.format.getMinExp();
4573 const int maxExp = ctx.format.getMaxExp();
4574 // Restrictions from the GLSL.std.450 instruction set.
4575 // See Khronos bugzilla 11180 for rationale.
4576 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4577 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4578 if (iargs.b.lo() < minExp) ret |= 0.0;
4579 if (!ret.isFinite(ctx.format.getMaxValue())) ret |= TCU_NAN;
4580 return ctx.format.convert(ret);
4581 }
4582 };
4583
4584 template <>
doApply(const EvalContext & ctx,const IArgs & iargs) const4585 Interval LdExp <Signature<double, double, int>>::doApply(const EvalContext& ctx, const IArgs& iargs) const
4586 {
4587 const int minExp = ctx.format.getMinExp();
4588 const int maxExp = ctx.format.getMaxExp();
4589 // Restrictions from the GLSL.std.450 instruction set.
4590 // See Khronos bugzilla 11180 for rationale.
4591 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4592 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4593 // Add 1ULP precision tolerance to account for differing rounding modes between the GPU and deLdExp.
4594 ret += Interval(-ctx.format.ulp(ret.lo()), ctx.format.ulp(ret.hi()));
4595 if (iargs.b.lo() < minExp) ret |= 0.0;
4596 if (!ret.isFinite(ctx.format.getMaxValue())) ret |= TCU_NAN;
4597 return ctx.format.convert(ret);
4598 }
4599
4600 template<int Rows, int Columns, class T>
4601 class Transpose : public PrimitiveFunc<Signature<Matrix<T, Rows, Columns>,
4602 Matrix<T, Columns, Rows> > >
4603 {
4604 public:
4605 typedef typename Transpose::IRet IRet;
4606 typedef typename Transpose::IArgs IArgs;
4607
getName(void) const4608 string getName (void) const
4609 {
4610 return "transpose";
4611 }
4612
4613 protected:
doApply(const EvalContext &,const IArgs & iargs) const4614 IRet doApply (const EvalContext&, const IArgs& iargs) const
4615 {
4616 IRet ret;
4617
4618 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
4619 {
4620 for (int colNdx = 0; colNdx < Columns; ++colNdx)
4621 ret(rowNdx, colNdx) = iargs.a(colNdx, rowNdx);
4622 }
4623
4624 return ret;
4625 }
4626 };
4627
4628 template<typename Ret, typename Arg0, typename Arg1>
4629 class MulFunc : public PrimitiveFunc<Signature<Ret, Arg0, Arg1> >
4630 {
4631 public:
getName(void) const4632 string getName (void) const { return "mul"; }
4633
4634 protected:
doPrint(ostream & os,const BaseArgExprs & args) const4635 void doPrint (ostream& os, const BaseArgExprs& args) const
4636 {
4637 os << "(" << *args[0] << " * " << *args[1] << ")";
4638 }
4639 };
4640
4641 template<typename T, int LeftRows, int Middle, int RightCols>
4642 class MatMul : public MulFunc<Matrix<T, LeftRows, RightCols>,
4643 Matrix<T, LeftRows, Middle>,
4644 Matrix<T, Middle, RightCols> >
4645 {
4646 protected:
4647 typedef typename MatMul::IRet IRet;
4648 typedef typename MatMul::IArgs IArgs;
4649 typedef typename MatMul::IArg0 IArg0;
4650 typedef typename MatMul::IArg1 IArg1;
4651
doApply(const EvalContext & ctx,const IArgs & iargs) const4652 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4653 {
4654 const IArg0& left = iargs.a;
4655 const IArg1& right = iargs.b;
4656 IRet ret;
4657
4658 for (int row = 0; row < LeftRows; ++row)
4659 {
4660 for (int col = 0; col < RightCols; ++col)
4661 {
4662 Interval element (0.0);
4663
4664 for (int ndx = 0; ndx < Middle; ++ndx)
4665 element = call<Add< Signature<T, T, T> > >(ctx, element,
4666 call<Mul< Signature<T, T, T> > >(ctx, left[ndx][row], right[col][ndx]));
4667
4668 ret[col][row] = element;
4669 }
4670 }
4671
4672 return ret;
4673 }
4674 };
4675
4676 template<typename T, int Rows, int Cols>
4677 class VecMatMul : public MulFunc<Vector<T, Cols>,
4678 Vector<T, Rows>,
4679 Matrix<T, Rows, Cols> >
4680 {
4681 public:
4682 typedef typename VecMatMul::IRet IRet;
4683 typedef typename VecMatMul::IArgs IArgs;
4684 typedef typename VecMatMul::IArg0 IArg0;
4685 typedef typename VecMatMul::IArg1 IArg1;
4686
4687 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4688 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4689 {
4690 const IArg0& left = iargs.a;
4691 const IArg1& right = iargs.b;
4692 IRet ret;
4693
4694 for (int col = 0; col < Cols; ++col)
4695 {
4696 Interval element (0.0);
4697
4698 for (int row = 0; row < Rows; ++row)
4699 element = call<Add< Signature<T, T, T> > >(ctx, element, call<Mul< Signature<T, T, T> > >(ctx, left[row], right[col][row]));
4700
4701 ret[col] = element;
4702 }
4703
4704 return ret;
4705 }
4706 };
4707
4708 template<int Rows, int Cols, class T>
4709 class MatVecMul : public MulFunc<Vector<T, Rows>,
4710 Matrix<T, Rows, Cols>,
4711 Vector<T, Cols> >
4712 {
4713 public:
4714 typedef typename MatVecMul::IRet IRet;
4715 typedef typename MatVecMul::IArgs IArgs;
4716 typedef typename MatVecMul::IArg0 IArg0;
4717 typedef typename MatVecMul::IArg1 IArg1;
4718
4719 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4720 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4721 {
4722 const IArg0& left = iargs.a;
4723 const IArg1& right = iargs.b;
4724
4725 return call<VecMatMul<T, Cols, Rows> >(ctx, right,
4726 call<Transpose<Rows, Cols, T> >(ctx, left));
4727 }
4728 };
4729
4730 template<int Rows, int Cols, class T>
4731 class OuterProduct : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>,
4732 Vector<T, Rows>,
4733 Vector<T, Cols> > >
4734 {
4735 public:
4736 typedef typename OuterProduct::IRet IRet;
4737 typedef typename OuterProduct::IArgs IArgs;
4738
getName(void) const4739 string getName (void) const
4740 {
4741 return "outerProduct";
4742 }
4743
4744 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4745 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4746 {
4747 IRet ret;
4748
4749 for (int row = 0; row < Rows; ++row)
4750 {
4751 for (int col = 0; col < Cols; ++col)
4752 ret[col][row] = call<Mul< Signature<T, T, T> > >(ctx, iargs.a[row], iargs.b[col]);
4753 }
4754
4755 return ret;
4756 }
4757 };
4758
4759 template<int Rows, int Cols, class T>
outerProduct(const ExprP<Vector<T,Rows>> & left,const ExprP<Vector<T,Cols>> & right)4760 ExprP<Matrix<T, Rows, Cols> > outerProduct (const ExprP<Vector<T, Rows> >& left,
4761 const ExprP<Vector<T, Cols> >& right)
4762 {
4763 return app<OuterProduct<Rows, Cols, T> >(left, right);
4764 }
4765
4766 template<class T>
4767 class DeterminantBase : public DerivedFunc<T>
4768 {
4769 public:
getName(void) const4770 string getName (void) const { return "determinant"; }
4771 };
4772
// Forward declarations of the size-parameterized determinant function classes
// (Determinant on float matrices, Determinant16bit on deFloat16 matrices,
// Determinant64bit on double matrices); they are defined only through
// explicit specializations for sizes 2, 3 and 4 further down in this file.
4773 template<int Size> class Determinant;
4774 template<int Size> class Determinant16bit;
4775 template<int Size> class Determinant64bit;
4776
4777 template<int Size>
determinant(ExprP<Matrix<float,Size,Size>> mat)4778 ExprP<float> determinant (ExprP<Matrix<float, Size, Size> > mat)
4779 {
4780 return app<Determinant<Size> >(mat);
4781 }
4782
4783 template<int Size>
determinant(ExprP<Matrix<deFloat16,Size,Size>> mat)4784 ExprP<deFloat16> determinant (ExprP<Matrix<deFloat16, Size, Size> > mat)
4785 {
4786 return app<Determinant16bit<Size> >(mat);
4787 }
4788
4789 template<int Size>
determinant(ExprP<Matrix<double,Size,Size>> mat)4790 ExprP<double> determinant (ExprP<Matrix<double, Size, Size> > mat)
4791 {
4792 return app<Determinant64bit<Size> >(mat);
4793 }
4794
4795 template<>
4796 class Determinant<2> : public DeterminantBase<Signature<float, Matrix<float, 2, 2> > >
4797 {
4798 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4799 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4800 {
4801 ExprP<Mat2> mat = args.a;
4802
4803 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4804 }
4805 };
4806
4807 template<>
4808 class Determinant<3> : public DeterminantBase<Signature<float, Matrix<float, 3, 3> > >
4809 {
4810 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4811 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4812 {
4813 ExprP<Mat3> mat = args.a;
4814
4815 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
4816 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
4817 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
4818 }
4819 };
4820
4821 template<>
4822 class Determinant<4> : public DeterminantBase<Signature<float, Matrix<float, 4, 4> > >
4823 {
4824 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4825 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4826 {
4827 ExprP<Mat4> mat = args.a;
4828 ExprP<Mat3> minors[4];
4829
4830 for (int ndx = 0; ndx < 4; ++ndx)
4831 {
4832 ExprP<Vec4> minorColumns[3];
4833 ExprP<Vec3> columns[3];
4834
4835 for (int col = 0; col < 3; ++col)
4836 minorColumns[col] = mat[col < ndx ? col : col + 1];
4837
4838 for (int col = 0; col < 3; ++col)
4839 columns[col] = vec3(minorColumns[0][col+1],
4840 minorColumns[1][col+1],
4841 minorColumns[2][col+1]);
4842
4843 minors[ndx] = bindExpression("minor", ctx,
4844 mat3(columns[0], columns[1], columns[2]));
4845 }
4846
4847 return (mat[0][0] * determinant(minors[0]) -
4848 mat[1][0] * determinant(minors[1]) +
4849 mat[2][0] * determinant(minors[2]) -
4850 mat[3][0] * determinant(minors[3]));
4851 }
4852 };
4853
4854 template<>
4855 class Determinant16bit<2> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 2, 2> > >
4856 {
4857 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4858 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4859 {
4860 ExprP<Mat2_16b> mat = args.a;
4861
4862 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4863 }
4864 };
4865
4866 template<>
4867 class Determinant16bit<3> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 3, 3> > >
4868 {
4869 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4870 ExprP<Ret> doExpand(ExpandContext&, const ArgExprs& args) const
4871 {
4872 ExprP<Mat3_16b> mat = args.a;
4873
4874 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
4875 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
4876 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
4877 }
4878 };
4879
4880 template<>
4881 class Determinant16bit<4> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 4, 4> > >
4882 {
4883 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4884 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
4885 {
4886 ExprP<Mat4_16b> mat = args.a;
4887 ExprP<Mat3_16b> minors[4];
4888
4889 for (int ndx = 0; ndx < 4; ++ndx)
4890 {
4891 ExprP<Vec4_16Bit> minorColumns[3];
4892 ExprP<Vec3_16Bit> columns[3];
4893
4894 for (int col = 0; col < 3; ++col)
4895 minorColumns[col] = mat[col < ndx ? col : col + 1];
4896
4897 for (int col = 0; col < 3; ++col)
4898 columns[col] = vec3(minorColumns[0][col + 1],
4899 minorColumns[1][col + 1],
4900 minorColumns[2][col + 1]);
4901
4902 minors[ndx] = bindExpression("minor", ctx,
4903 mat3(columns[0], columns[1], columns[2]));
4904 }
4905
4906 return (mat[0][0] * determinant(minors[0]) -
4907 mat[1][0] * determinant(minors[1]) +
4908 mat[2][0] * determinant(minors[2]) -
4909 mat[3][0] * determinant(minors[3]));
4910 }
4911 };
4912
4913 template<>
4914 class Determinant64bit<2> : public DeterminantBase<Signature<double, Matrix<double, 2, 2> > >
4915 {
4916 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4917 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4918 {
4919 ExprP<Matrix2d> mat = args.a;
4920
4921 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4922 }
4923 };
4924
4925 template<>
4926 class Determinant64bit<3> : public DeterminantBase<Signature<double, Matrix<double, 3, 3> > >
4927 {
4928 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4929 ExprP<Ret> doExpand(ExpandContext&, const ArgExprs& args) const
4930 {
4931 ExprP<Matrix3d> mat = args.a;
4932
4933 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
4934 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
4935 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
4936 }
4937 };
4938
4939 template<>
4940 class Determinant64bit<4> : public DeterminantBase<Signature<double, Matrix<double, 4, 4> > >
4941 {
4942 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4943 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
4944 {
4945 ExprP<Matrix4d> mat = args.a;
4946 ExprP<Matrix3d> minors[4];
4947
4948 for (int ndx = 0; ndx < 4; ++ndx)
4949 {
4950 ExprP<Vec4_64Bit> minorColumns[3];
4951 ExprP<Vec3_64Bit> columns[3];
4952
4953 for (int col = 0; col < 3; ++col)
4954 minorColumns[col] = mat[col < ndx ? col : col + 1];
4955
4956 for (int col = 0; col < 3; ++col)
4957 columns[col] = vec3(minorColumns[0][col + 1],
4958 minorColumns[1][col + 1],
4959 minorColumns[2][col + 1]);
4960
4961 minors[ndx] = bindExpression("minor", ctx,
4962 mat3(columns[0], columns[1], columns[2]));
4963 }
4964
4965 return (mat[0][0] * determinant(minors[0]) -
4966 mat[1][0] * determinant(minors[1]) +
4967 mat[2][0] * determinant(minors[2]) -
4968 mat[3][0] * determinant(minors[3]));
4969 }
4970 };
4971
4972 template<int Size> class Inverse;
4973
4974 template <int Size>
inverse(ExprP<Matrix<float,Size,Size>> mat)4975 ExprP<Matrix<float, Size, Size> > inverse (ExprP<Matrix<float, Size, Size> > mat)
4976 {
4977 return app<Inverse<Size> >(mat);
4978 }
4979
4980 template<int Size> class Inverse16bit;
4981
4982 template <int Size>
inverse(ExprP<Matrix<deFloat16,Size,Size>> mat)4983 ExprP<Matrix<deFloat16, Size, Size> > inverse (ExprP<Matrix<deFloat16, Size, Size> > mat)
4984 {
4985 return app<Inverse16bit<Size> >(mat);
4986 }
4987
4988 template<int Size> class Inverse64bit;
4989
4990 template <int Size>
inverse(ExprP<Matrix<double,Size,Size>> mat)4991 ExprP<Matrix<double, Size, Size> > inverse (ExprP<Matrix<double, Size, Size> > mat)
4992 {
4993 return app<Inverse64bit<Size> >(mat);
4994 }
4995
4996 template<>
4997 class Inverse<2> : public DerivedFunc<Signature<Mat2, Mat2> >
4998 {
4999 public:
getName(void) const5000 string getName (void) const
5001 {
5002 return "inverse";
5003 }
5004
5005 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5006 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5007 {
5008 ExprP<Mat2> mat = args.a;
5009 ExprP<float> det = bindExpression("det", ctx, determinant(mat));
5010
5011 return mat2(vec2(mat[1][1] / det, -mat[0][1] / det),
5012 vec2(-mat[1][0] / det, mat[0][0] / det));
5013 }
5014 };
5015
5016 template<>
5017 class Inverse<3> : public DerivedFunc<Signature<Mat3, Mat3> >
5018 {
5019 public:
getName(void) const5020 string getName (void) const
5021 {
5022 return "inverse";
5023 }
5024
5025 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5026 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5027 {
5028 ExprP<Mat3> mat = args.a;
5029 ExprP<Mat2> invA = bindExpression("invA", ctx,
5030 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5031 vec2(mat[1][0], mat[1][1]))));
5032
5033 ExprP<Vec2> matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5034 ExprP<Vec2> matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5035 ExprP<float> matD = bindExpression("matD", ctx, mat[2][2]);
5036
5037 ExprP<float> schur = bindExpression("schur", ctx,
5038 constant(1.0f) /
5039 (matD - dot(matC * invA, matB)));
5040
5041 ExprP<Vec2> t1 = invA * matB;
5042 ExprP<Vec2> t2 = t1 * schur;
5043 ExprP<Mat2> t3 = outerProduct(t2, matC);
5044 ExprP<Mat2> t4 = t3 * invA;
5045 ExprP<Mat2> t5 = invA + t4;
5046 ExprP<Mat2> blockA = bindExpression("blockA", ctx, t5);
5047 ExprP<Vec2> blockB = bindExpression("blockB", ctx,
5048 (invA * matB) * -schur);
5049 ExprP<Vec2> blockC = bindExpression("blockC", ctx,
5050 (matC * invA) * -schur);
5051
5052 return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
5053 vec3(blockA[1][0], blockA[1][1], blockC[1]),
5054 vec3(blockB[0], blockB[1], schur));
5055 }
5056 };
5057
5058 template<>
5059 class Inverse<4> : public DerivedFunc<Signature<Mat4, Mat4> >
5060 {
5061 public:
getName(void) const5062 string getName (void) const { return "inverse"; }
5063
5064 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5065 ExprP<Ret> doExpand (ExpandContext& ctx,
5066 const ArgExprs& args) const
5067 {
5068 ExprP<Mat4> mat = args.a;
5069 ExprP<Mat2> invA = bindExpression("invA", ctx,
5070 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5071 vec2(mat[1][0], mat[1][1]))));
5072 ExprP<Mat2> matB = bindExpression("matB", ctx,
5073 mat2(vec2(mat[2][0], mat[2][1]),
5074 vec2(mat[3][0], mat[3][1])));
5075 ExprP<Mat2> matC = bindExpression("matC", ctx,
5076 mat2(vec2(mat[0][2], mat[0][3]),
5077 vec2(mat[1][2], mat[1][3])));
5078 ExprP<Mat2> matD = bindExpression("matD", ctx,
5079 mat2(vec2(mat[2][2], mat[2][3]),
5080 vec2(mat[3][2], mat[3][3])));
5081 ExprP<Mat2> schur = bindExpression("schur", ctx,
5082 inverse(matD + -(matC * invA * matB)));
5083 ExprP<Mat2> blockA = bindExpression("blockA", ctx,
5084 invA + (invA * matB * schur * matC * invA));
5085 ExprP<Mat2> blockB = bindExpression("blockB", ctx,
5086 (-invA) * matB * schur);
5087 ExprP<Mat2> blockC = bindExpression("blockC", ctx,
5088 (-schur) * matC * invA);
5089
5090 return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5091 vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5092 vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5093 vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5094 }
5095 };
5096
5097 template<>
5098 class Inverse16bit<2> : public DerivedFunc<Signature<Mat2_16b, Mat2_16b> >
5099 {
5100 public:
getName(void) const5101 string getName (void) const
5102 {
5103 return "inverse";
5104 }
5105
5106 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5107 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5108 {
5109 ExprP<Mat2_16b> mat = args.a;
5110 ExprP<deFloat16> det = bindExpression("det", ctx, determinant(mat));
5111
5112 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)),
5113 vec2((-mat[1][0] / det), (mat[0][0] / det)));
5114 }
5115 };
5116
5117 template<>
5118 class Inverse16bit<3> : public DerivedFunc<Signature<Mat3_16b, Mat3_16b> >
5119 {
5120 public:
getName(void) const5121 string getName(void) const
5122 {
5123 return "inverse";
5124 }
5125
5126 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5127 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
5128 {
5129 ExprP<Mat3_16b> mat = args.a;
5130 ExprP<Mat2_16b> invA = bindExpression("invA", ctx,
5131 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5132 vec2(mat[1][0], mat[1][1]))));
5133
5134 ExprP<Vec2_16Bit> matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5135 ExprP<Vec2_16Bit> matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5136 ExprP<Mat3_16b::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);
5137
5138 ExprP<Mat3_16b::Scalar> schur = bindExpression("schur", ctx,
5139 constant((deFloat16)FLOAT16_1_0) /
5140 (matD - dot(matC * invA, matB)));
5141
5142 ExprP<Vec2_16Bit> t1 = invA * matB;
5143 ExprP<Vec2_16Bit> t2 = t1 * schur;
5144 ExprP<Mat2_16b> t3 = outerProduct(t2, matC);
5145 ExprP<Mat2_16b> t4 = t3 * invA;
5146 ExprP<Mat2_16b> t5 = invA + t4;
5147 ExprP<Mat2_16b> blockA = bindExpression("blockA", ctx, t5);
5148 ExprP<Vec2_16Bit> blockB = bindExpression("blockB", ctx,
5149 (invA * matB) * -schur);
5150 ExprP<Vec2_16Bit> blockC = bindExpression("blockC", ctx,
5151 (matC * invA) * -schur);
5152
5153 return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
5154 vec3(blockA[1][0], blockA[1][1], blockC[1]),
5155 vec3(blockB[0], blockB[1], schur));
5156 }
5157 };
5158
5159 template<>
5160 class Inverse16bit<4> : public DerivedFunc<Signature<Mat4_16b, Mat4_16b> >
5161 {
5162 public:
getName(void) const5163 string getName(void) const { return "inverse"; }
5164
5165 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5166 ExprP<Ret> doExpand(ExpandContext& ctx,
5167 const ArgExprs& args) const
5168 {
5169 ExprP<Mat4_16b> mat = args.a;
5170 ExprP<Mat2_16b> invA = bindExpression("invA", ctx,
5171 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5172 vec2(mat[1][0], mat[1][1]))));
5173 ExprP<Mat2_16b> matB = bindExpression("matB", ctx,
5174 mat2(vec2(mat[2][0], mat[2][1]),
5175 vec2(mat[3][0], mat[3][1])));
5176 ExprP<Mat2_16b> matC = bindExpression("matC", ctx,
5177 mat2(vec2(mat[0][2], mat[0][3]),
5178 vec2(mat[1][2], mat[1][3])));
5179 ExprP<Mat2_16b> matD = bindExpression("matD", ctx,
5180 mat2(vec2(mat[2][2], mat[2][3]),
5181 vec2(mat[3][2], mat[3][3])));
5182 ExprP<Mat2_16b> schur = bindExpression("schur", ctx,
5183 inverse(matD + -(matC * invA * matB)));
5184 ExprP<Mat2_16b> blockA = bindExpression("blockA", ctx,
5185 invA + (invA * matB * schur * matC * invA));
5186 ExprP<Mat2_16b> blockB = bindExpression("blockB", ctx,
5187 (-invA) * matB * schur);
5188 ExprP<Mat2_16b> blockC = bindExpression("blockC", ctx,
5189 (-schur) * matC * invA);
5190
5191 return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5192 vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5193 vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5194 vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5195 }
5196 };
5197
5198 template<>
5199 class Inverse64bit<2> : public DerivedFunc<Signature<Matrix2d, Matrix2d> >
5200 {
5201 public:
getName(void) const5202 string getName (void) const
5203 {
5204 return "inverse";
5205 }
5206
5207 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5208 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5209 {
5210 ExprP<Matrix2d> mat = args.a;
5211 ExprP<double> det = bindExpression("det", ctx, determinant(mat));
5212
5213 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)),
5214 vec2((-mat[1][0] / det), (mat[0][0] / det)));
5215 }
5216 };
5217
5218 template<>
5219 class Inverse64bit<3> : public DerivedFunc<Signature<Matrix3d, Matrix3d> >
5220 {
5221 public:
getName(void) const5222 string getName(void) const
5223 {
5224 return "inverse";
5225 }
5226
5227 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5228 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
5229 {
5230 ExprP<Matrix3d> mat = args.a;
5231 ExprP<Matrix2d> invA = bindExpression("invA", ctx,
5232 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5233 vec2(mat[1][0], mat[1][1]))));
5234
5235 ExprP<Vec2_64Bit> matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5236 ExprP<Vec2_64Bit> matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5237 ExprP<Matrix3d::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);
5238
5239 ExprP<Matrix3d::Scalar> schur = bindExpression("schur", ctx,
5240 constant(1.0) /
5241 (matD - dot(matC * invA, matB)));
5242
5243 ExprP<Vec2_64Bit> t1 = invA * matB;
5244 ExprP<Vec2_64Bit> t2 = t1 * schur;
5245 ExprP<Matrix2d> t3 = outerProduct(t2, matC);
5246 ExprP<Matrix2d> t4 = t3 * invA;
5247 ExprP<Matrix2d> t5 = invA + t4;
5248 ExprP<Matrix2d> blockA = bindExpression("blockA", ctx, t5);
5249 ExprP<Vec2_64Bit> blockB = bindExpression("blockB", ctx,
5250 (invA * matB) * -schur);
5251 ExprP<Vec2_64Bit> blockC = bindExpression("blockC", ctx,
5252 (matC * invA) * -schur);
5253
5254 return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
5255 vec3(blockA[1][0], blockA[1][1], blockC[1]),
5256 vec3(blockB[0], blockB[1], schur));
5257 }
5258 };
5259
5260 template<>
5261 class Inverse64bit<4> : public DerivedFunc<Signature<Matrix4d, Matrix4d> >
5262 {
5263 public:
getName(void) const5264 string getName(void) const { return "inverse"; }
5265
5266 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5267 ExprP<Ret> doExpand(ExpandContext& ctx,
5268 const ArgExprs& args) const
5269 {
5270 ExprP<Matrix4d> mat = args.a;
5271 ExprP<Matrix2d> invA = bindExpression("invA", ctx,
5272 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5273 vec2(mat[1][0], mat[1][1]))));
5274 ExprP<Matrix2d> matB = bindExpression("matB", ctx,
5275 mat2(vec2(mat[2][0], mat[2][1]),
5276 vec2(mat[3][0], mat[3][1])));
5277 ExprP<Matrix2d> matC = bindExpression("matC", ctx,
5278 mat2(vec2(mat[0][2], mat[0][3]),
5279 vec2(mat[1][2], mat[1][3])));
5280 ExprP<Matrix2d> matD = bindExpression("matD", ctx,
5281 mat2(vec2(mat[2][2], mat[2][3]),
5282 vec2(mat[3][2], mat[3][3])));
5283 ExprP<Matrix2d> schur = bindExpression("schur", ctx,
5284 inverse(matD + -(matC * invA * matB)));
5285 ExprP<Matrix2d> blockA = bindExpression("blockA", ctx,
5286 invA + (invA * matB * schur * matC * invA));
5287 ExprP<Matrix2d> blockB = bindExpression("blockB", ctx,
5288 (-invA) * matB * schur);
5289 ExprP<Matrix2d> blockC = bindExpression("blockC", ctx,
5290 (-schur) * matC * invA);
5291
5292 return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5293 vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5294 vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5295 vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5296 }
5297 };
5298
5299 //Signature<float, float, float, float>
5300 //Signature<deFloat16, deFloat16, deFloat16, deFloat16>
5301 //Signature<double, double, double, double>
5302 template <class T>
5303 class Fma : public DerivedFunc<T>
5304 {
5305 public:
5306 typedef typename Fma::ArgExprs ArgExprs;
5307 typedef typename Fma::Ret Ret;
5308
getName(void) const5309 string getName (void) const
5310 {
5311 return "fma";
5312 }
5313
5314 protected:
doExpand(ExpandContext &,const ArgExprs & x) const5315 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& x) const
5316 {
5317 return x.a * x.b + x.c;
5318 }
5319 };
5320
5321 } // Functions
5322
5323 using namespace Functions;
5324
5325 template <typename T>
operator [](int i) const5326 ExprP<typename T::Element> ContainerExprPBase<T>::operator[] (int i) const
5327 {
5328 return Functions::getComponent(exprP<T>(*this), i);
5329 }
5330
// float expression addition: builds an Add node.
ExprP<float> operator+ (const ExprP<float>& arg0, const ExprP<float>& arg1)
{
	typedef Add< Signature<float, float, float> > AddFunc;

	return app<AddFunc>(arg0, arg1);
}
5335
// deFloat16 expression addition: builds an Add node.
ExprP<deFloat16> operator+ (const ExprP<deFloat16>& arg0, const ExprP<deFloat16>& arg1)
{
	typedef Add< Signature<deFloat16, deFloat16, deFloat16> > AddFunc;

	return app<AddFunc>(arg0, arg1);
}
5340
//! Scalar double addition: applies the Add functor to both operand expressions.
ExprP<double> operator+ (const ExprP<double>& arg0, const ExprP<double>& arg1)
{
    typedef Add< Signature<double, double, double> > AddFunc;

    return app<AddFunc>(arg0, arg1);
}
5345
5346 template <typename T>
operator -(const ExprP<T> & arg0,const ExprP<T> & arg1)5347 ExprP<T> operator- (const ExprP<T>& arg0, const ExprP<T>& arg1)
5348 {
5349 return app<Sub <Signature <T,T,T> > >(arg0, arg1);
5350 }
5351
5352 template <typename T>
operator -(const ExprP<T> & arg0)5353 ExprP<T> operator- (const ExprP<T>& arg0)
5354 {
5355 return app<Negate< Signature<T, T> > >(arg0);
5356 }
5357
//! Scalar float multiplication: applies the Mul functor to both operand expressions.
ExprP<float> operator* (const ExprP<float>& arg0, const ExprP<float>& arg1)
{
    typedef Mul< Signature<float, float, float> > MulFunc;

    return app<MulFunc>(arg0, arg1);
}
5362
//! Scalar deFloat16 multiplication: applies the Mul functor to both operand expressions.
ExprP<deFloat16> operator* (const ExprP<deFloat16>& arg0, const ExprP<deFloat16>& arg1)
{
    typedef Mul< Signature<deFloat16, deFloat16, deFloat16> > MulFunc;

    return app<MulFunc>(arg0, arg1);
}
5367
//! Scalar double multiplication: applies the Mul functor to both operand expressions.
ExprP<double> operator* (const ExprP<double>& arg0, const ExprP<double>& arg1)
{
    typedef Mul< Signature<double, double, double> > MulFunc;

    return app<MulFunc>(arg0, arg1);
}
5372
5373 template <typename T>
operator /(const ExprP<T> & arg0,const ExprP<T> & arg1)5374 ExprP<T> operator/ (const ExprP<T>& arg0, const ExprP<T>& arg1)
5375 {
5376 return app<Div< Signature<T, T, T> > >(arg0, arg1);
5377 }
5378
5379
5380 template <typename Sig_, int Size>
5381 class GenFunc : public PrimitiveFunc<Signature<
5382 typename ContainerOf<typename Sig_::Ret, Size>::Container,
5383 typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5384 typename ContainerOf<typename Sig_::Arg1, Size>::Container,
5385 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5386 typename ContainerOf<typename Sig_::Arg3, Size>::Container> >
5387 {
5388 public:
5389 typedef typename GenFunc::IArgs IArgs;
5390 typedef typename GenFunc::IRet IRet;
5391
GenFunc(const Func<Sig_> & scalarFunc)5392 GenFunc (const Func<Sig_>& scalarFunc) : m_func (scalarFunc) {}
5393
getSpirvCase(void) const5394 SpirVCaseT getSpirvCase (void) const
5395 {
5396 return m_func.getSpirvCase();
5397 }
5398
getName(void) const5399 string getName (void) const
5400 {
5401 return m_func.getName();
5402 }
5403
getOutParamIndex(void) const5404 int getOutParamIndex (void) const
5405 {
5406 return m_func.getOutParamIndex();
5407 }
5408
getRequiredExtension(void) const5409 string getRequiredExtension (void) const
5410 {
5411 return m_func.getRequiredExtension();
5412 }
5413
getInputRange(const bool is16bit) const5414 Interval getInputRange (const bool is16bit) const
5415 {
5416 return m_func.getInputRange(is16bit);
5417 }
5418
5419 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5420 void doPrint (ostream& os, const BaseArgExprs& args) const
5421 {
5422 m_func.print(os, args);
5423 }
5424
doApply(const EvalContext & ctx,const IArgs & iargs) const5425 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
5426 {
5427 IRet ret;
5428
5429 for (int ndx = 0; ndx < Size; ++ndx)
5430 {
5431 ret[ndx] =
5432 m_func.apply(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5433 }
5434
5435 return ret;
5436 }
5437
doFail(const EvalContext & ctx,const IArgs & iargs) const5438 IRet doFail (const EvalContext& ctx, const IArgs& iargs) const
5439 {
5440 IRet ret;
5441
5442 for (int ndx = 0; ndx < Size; ++ndx)
5443 {
5444 ret[ndx] =
5445 m_func.fail(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5446 }
5447
5448 return ret;
5449 }
5450
doGetUsedFuncs(FuncSet & dst) const5451 void doGetUsedFuncs (FuncSet& dst) const
5452 {
5453 m_func.getUsedFuncs(dst);
5454 }
5455
5456 const Func<Sig_>& m_func;
5457 };
5458
5459 template <typename F, int Size>
5460 class VectorizedFunc : public GenFunc<typename F::Sig, Size>
5461 {
5462 public:
VectorizedFunc(void)5463 VectorizedFunc (void) : GenFunc<typename F::Sig, Size>(instance<F>()) {}
5464 };
5465
5466 template <typename Sig_, int Size>
5467 class FixedGenFunc : public PrimitiveFunc <Signature<
5468 typename ContainerOf<typename Sig_::Ret, Size>::Container,
5469 typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5470 typename Sig_::Arg1,
5471 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5472 typename ContainerOf<typename Sig_::Arg3, Size>::Container> >
5473 {
5474 public:
5475 typedef typename FixedGenFunc::IArgs IArgs;
5476 typedef typename FixedGenFunc::IRet IRet;
5477
getName(void) const5478 string getName (void) const
5479 {
5480 return this->doGetScalarFunc().getName();
5481 }
5482
getSpirvCase(void) const5483 SpirVCaseT getSpirvCase (void) const
5484 {
5485 return this->doGetScalarFunc().getSpirvCase();
5486 }
5487
5488 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5489 void doPrint (ostream& os, const BaseArgExprs& args) const
5490 {
5491 this->doGetScalarFunc().print(os, args);
5492 }
5493
doApply(const EvalContext & ctx,const IArgs & iargs) const5494 IRet doApply (const EvalContext& ctx,
5495 const IArgs& iargs) const
5496 {
5497 IRet ret;
5498 const Func<Sig_>& func = this->doGetScalarFunc();
5499
5500 for (int ndx = 0; ndx < Size; ++ndx)
5501 ret[ndx] = func.apply(ctx, iargs.a[ndx], iargs.b, iargs.c[ndx], iargs.d[ndx]);
5502
5503 return ret;
5504 }
5505
5506 virtual const Func<Sig_>& doGetScalarFunc (void) const = 0;
5507 };
5508
5509 template <typename F, int Size>
5510 class FixedVecFunc : public FixedGenFunc<typename F::Sig, Size>
5511 {
5512 protected:
doGetScalarFunc(void) const5513 const Func<typename F::Sig>& doGetScalarFunc (void) const { return instance<F>(); }
5514 };
5515
5516 template<typename Sig>
5517 struct GenFuncs
5518 {
GenFuncsvkt::shaderexecutor::GenFuncs5519 GenFuncs (const Func<Sig>& func_,
5520 const GenFunc<Sig, 2>& func2_,
5521 const GenFunc<Sig, 3>& func3_,
5522 const GenFunc<Sig, 4>& func4_)
5523 : func (func_)
5524 , func2 (func2_)
5525 , func3 (func3_)
5526 , func4 (func4_)
5527 {}
5528
5529 const Func<Sig>& func;
5530 const GenFunc<Sig, 2>& func2;
5531 const GenFunc<Sig, 3>& func3;
5532 const GenFunc<Sig, 4>& func4;
5533 };
5534
5535 template<typename F>
makeVectorizedFuncs(void)5536 GenFuncs<typename F::Sig> makeVectorizedFuncs (void)
5537 {
5538 return GenFuncs<typename F::Sig>(instance<F>(),
5539 instance<VectorizedFunc<F, 2> >(),
5540 instance<VectorizedFunc<F, 3> >(),
5541 instance<VectorizedFunc<F, 4> >());
5542 }
5543
5544 template<typename T, int Size>
operator /(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5545 ExprP<Vector<T, Size> > operator/(const ExprP<Vector<T, Size> >& arg0,
5546 const ExprP<T>& arg1)
5547 {
5548 return app<FixedVecFunc<Div< Signature<T, T, T> >, Size> >(arg0, arg1);
5549 }
5550
5551 template<typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0)5552 ExprP<Vector<T, Size> > operator-(const ExprP<Vector<T, Size> >& arg0)
5553 {
5554 return app<VectorizedFunc<Negate< Signature<T, T> >, Size> >(arg0);
5555 }
5556
5557 template<typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5558 ExprP<Vector<T, Size> > operator-(const ExprP<Vector<T, Size> >& arg0,
5559 const ExprP<Vector<T, Size> >& arg1)
5560 {
5561 return app<VectorizedFunc<Sub<Signature<T, T, T> >, Size> >(arg0, arg1);
5562 }
5563
5564 template<int Size, typename T>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5565 ExprP<Vector<T, Size> > operator*(const ExprP<Vector<T, Size> >& arg0,
5566 const ExprP<T>& arg1)
5567 {
5568 return app<FixedVecFunc<Mul< Signature<T, T, T> >, Size> >(arg0, arg1);
5569 }
5570
5571 template<typename T, int Size>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5572 ExprP<Vector<T, Size> > operator*(const ExprP<Vector<T, Size> >& arg0,
5573 const ExprP<Vector<T, Size> >& arg1)
5574 {
5575 return app<VectorizedFunc<Mul< Signature<T, T, T> >, Size> >(arg0, arg1);
5576 }
5577
5578 template<int LeftRows, int Middle, int RightCols, typename T>
5579 ExprP<Matrix<T, LeftRows, RightCols> >
operator *(const ExprP<Matrix<T,LeftRows,Middle>> & left,const ExprP<Matrix<T,Middle,RightCols>> & right)5580 operator* (const ExprP<Matrix<T, LeftRows, Middle> >& left,
5581 const ExprP<Matrix<T, Middle, RightCols> >& right)
5582 {
5583 return app<MatMul<T, LeftRows, Middle, RightCols> >(left, right);
5584 }
5585
5586 template<int Rows, int Cols, typename T>
operator *(const ExprP<Vector<T,Cols>> & left,const ExprP<Matrix<T,Rows,Cols>> & right)5587 ExprP<Vector<T, Rows> > operator* (const ExprP<Vector<T, Cols> >& left,
5588 const ExprP<Matrix<T, Rows, Cols> >& right)
5589 {
5590 return app<VecMatMul<T, Rows, Cols> >(left, right);
5591 }
5592
5593 template<int Rows, int Cols, class T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<Vector<T,Rows>> & right)5594 ExprP<Vector<T, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
5595 const ExprP<Vector<T, Rows> >& right)
5596 {
5597 return app<MatVecMul<Rows, Cols, T> >(left, right);
5598 }
5599
5600 template<int Rows, int Cols, typename T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<T> & right)5601 ExprP<Matrix<T, Rows, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
5602 const ExprP<T>& right)
5603 {
5604 return app<ScalarMatFunc<Mul< Signature<T, T, T> >, Rows, Cols> >(left, right);
5605 }
5606
5607 template<int Rows, int Cols>
operator +(const ExprP<Matrix<float,Rows,Cols>> & left,const ExprP<Matrix<float,Rows,Cols>> & right)5608 ExprP<Matrix<float, Rows, Cols> > operator+ (const ExprP<Matrix<float, Rows, Cols> >& left,
5609 const ExprP<Matrix<float, Rows, Cols> >& right)
5610 {
5611 return app<CompMatFunc<Add< Signature<float, float, float> >,float, Rows, Cols> >(left, right);
5612 }
5613
5614 template<int Rows, int Cols>
operator +(const ExprP<Matrix<deFloat16,Rows,Cols>> & left,const ExprP<Matrix<deFloat16,Rows,Cols>> & right)5615 ExprP<Matrix<deFloat16, Rows, Cols> > operator+ (const ExprP<Matrix<deFloat16, Rows, Cols> >& left,
5616 const ExprP<Matrix<deFloat16, Rows, Cols> >& right)
5617 {
5618 return app<CompMatFunc<Add< Signature<deFloat16, deFloat16, deFloat16> >, deFloat16, Rows, Cols> >(left, right);
5619 }
5620
5621 template<int Rows, int Cols>
operator +(const ExprP<Matrix<double,Rows,Cols>> & left,const ExprP<Matrix<double,Rows,Cols>> & right)5622 ExprP<Matrix<double, Rows, Cols> > operator+ (const ExprP<Matrix<double, Rows, Cols> >& left,
5623 const ExprP<Matrix<double, Rows, Cols> >& right)
5624 {
5625 return app<CompMatFunc<Add< Signature<double, double, double> >, double, Rows, Cols> >(left, right);
5626 }
5627
5628 template<typename T, int Rows, int Cols>
operator -(const ExprP<Matrix<T,Rows,Cols>> & mat)5629 ExprP<Matrix<T, Rows, Cols> > operator- (const ExprP<Matrix<T, Rows, Cols> >& mat)
5630 {
5631 return app<MatNeg<T, Rows, Cols> >(mat);
5632 }
5633
5634 template <typename T>
5635 class Sampling
5636 {
5637 public:
genFixeds(const FloatFormat &,const Precision,vector<T> &,const Interval &) const5638 virtual void genFixeds (const FloatFormat&, const Precision, vector<T>&, const Interval&) const {}
genRandom(const FloatFormat &,const Precision,Random &,const Interval &) const5639 virtual T genRandom (const FloatFormat&,const Precision, Random&, const Interval&) const { return T(); }
removeNotInRange(vector<T> &,const Interval &,const Precision) const5640 virtual void removeNotInRange (vector<T>&, const Interval&, const Precision) const {}
5641 };
5642
5643 template <>
5644 class DefaultSampling<Void> : public Sampling<Void>
5645 {
5646 public:
genFixeds(const FloatFormat &,const Precision,vector<Void> & dst,const Interval &) const5647 void genFixeds (const FloatFormat&, const Precision, vector<Void>& dst, const Interval&) const { dst.push_back(Void()); }
5648 };
5649
5650 template <>
5651 class DefaultSampling<bool> : public Sampling<bool>
5652 {
5653 public:
genFixeds(const FloatFormat &,const Precision,vector<bool> & dst,const Interval &) const5654 void genFixeds (const FloatFormat&, const Precision, vector<bool>& dst, const Interval&) const
5655 {
5656 dst.push_back(true);
5657 dst.push_back(false);
5658 }
5659 };
5660
5661 template <>
5662 class DefaultSampling<int> : public Sampling<int>
5663 {
5664 public:
genRandom(const FloatFormat &,const Precision prec,Random & rnd,const Interval &) const5665 int genRandom (const FloatFormat&, const Precision prec, Random& rnd, const Interval&) const
5666 {
5667 const int exp = rnd.getInt(0, getNumBits(prec)-2);
5668 const int sign = rnd.getBool() ? -1 : 1;
5669
5670 return sign * rnd.getInt(0, (deInt32)1 << exp);
5671 }
5672
genFixeds(const FloatFormat &,const Precision,vector<int> & dst,const Interval &) const5673 void genFixeds (const FloatFormat&, const Precision, vector<int>& dst, const Interval&) const
5674 {
5675 dst.push_back(0);
5676 dst.push_back(-1);
5677 dst.push_back(1);
5678 }
5679
5680 private:
getNumBits(Precision prec)5681 static inline int getNumBits (Precision prec)
5682 {
5683 switch (prec)
5684 {
5685 case glu::PRECISION_LAST:
5686 case glu::PRECISION_MEDIUMP: return 16;
5687 case glu::PRECISION_HIGHP: return 32;
5688 default:
5689 DE_ASSERT(false);
5690 return 0;
5691 }
5692 }
5693 };
5694
5695 template <>
5696 class DefaultSampling<float> : public Sampling<float>
5697 {
5698 public:
5699 float genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5700 void genFixeds (const FloatFormat& format, const Precision prec, vector<float>& dst, const Interval& inputRange) const;
5701 void removeNotInRange (vector<float>& dst, const Interval& inputRange, const Precision prec) const;
5702 };
5703
5704 template <>
5705 class DefaultSampling<double> : public Sampling<double>
5706 {
5707 public:
5708 double genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5709 void genFixeds (const FloatFormat& format, const Precision prec, vector<double>& dst, const Interval& inputRange) const;
5710 void removeNotInRange (vector<double>& dst, const Interval& inputRange, const Precision prec) const;
5711 };
5712
isDenorm16(deFloat16 v)5713 static bool isDenorm16(deFloat16 v)
5714 {
5715 const deUint16 mantissa = 0x03FF;
5716 const deUint16 exponent = 0x7C00;
5717 return ((exponent & v) == 0 && (mantissa & v) != 0);
5718 }
5719
5720 //! Generate a random double from a reasonable general-purpose distribution.
randomDouble(const FloatFormat & format,Random & rnd,const Interval & inputRange)5721 double randomDouble(const FloatFormat& format, Random& rnd, const Interval& inputRange)
5722 {
5723 // No testing of subnormals. TODO: Could integrate float controls for some operations.
5724 const int minExp = format.getMinExp();
5725 const int maxExp = format.getMaxExp();
5726 const bool haveSubnormal = false;
5727 const double midpoint = inputRange.midpoint();
5728
5729 // Choose exponent so that the cumulative distribution is cubic.
5730 // This makes the probability distribution quadratic, with the peak centered on zero.
5731 const double minRoot = deCbrt(minExp - 0.5 - (haveSubnormal ? 1.0 : 0.0));
5732 const double maxRoot = deCbrt(maxExp + 0.5);
5733 const int fractionBits = format.getFractionBits();
5734 const int exp = int(deRoundEven(dePow(rnd.getDouble(minRoot, maxRoot), 3.0)));
5735
5736 // Generate some occasional special numbers
5737 switch (rnd.getInt(0, 64))
5738 {
5739 case 0: return inputRange.contains(0) ? 0 : midpoint;
5740 case 1: return inputRange.contains(TCU_INFINITY) ? TCU_INFINITY : midpoint;
5741 case 2: return inputRange.contains(-TCU_INFINITY) ? -TCU_INFINITY : midpoint;
5742 case 3: return inputRange.contains(TCU_NAN) ? TCU_NAN : midpoint;
5743 default: break;
5744 }
5745
5746 DE_ASSERT(fractionBits < std::numeric_limits<double>::digits);
5747
5748 // Normal number
5749 double base = deLdExp(1.0, exp);
5750 double quantum = deLdExp(1.0, exp - fractionBits); // smallest representable difference in the binade
5751 double significand = 0.0;
5752 switch (rnd.getInt(0, 16))
5753 {
5754 case 0: // The highest number in this binade, significand is all bits one.
5755 significand = base - quantum;
5756 break;
5757 case 1: // Significand is one.
5758 significand = quantum;
5759 break;
5760 case 2: // Significand is zero.
5761 significand = 0.0;
5762 break;
5763 default: // Random (evenly distributed) significand.
5764 {
5765 deUint64 intFraction = rnd.getUint64() & ((1 << fractionBits) - 1);
5766 significand = double(intFraction) * quantum;
5767 }
5768 }
5769
5770 // Produce positive numbers more often than negative.
5771 double value = (rnd.getInt(0, 3) == 0 ? -1.0 : 1.0) * (base + significand);
5772 return inputRange.contains(value) ? value : midpoint;
5773 }
5774
5775 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const5776 float DefaultSampling<float>::genRandom (const FloatFormat& format,
5777 Precision prec,
5778 Random& rnd,
5779 const Interval& inputRange) const
5780 {
5781 DE_UNREF(prec);
5782 return (float)randomDouble(format, rnd, inputRange);
5783 }
5784
5785 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<float> & dst,const Interval & inputRange) const5786 void DefaultSampling<float>::genFixeds (const FloatFormat& format, const Precision prec, vector<float>& dst, const Interval& inputRange) const
5787 {
5788 const int minExp = format.getMinExp();
5789 const int maxExp = format.getMaxExp();
5790 const int fractionBits = format.getFractionBits();
5791 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
5792 const float minNormalized = deFloatLdExp(1.0f, minExp);
5793 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
5794
5795 // NaN
5796 dst.push_back(TCU_NAN);
5797 // Zero
5798 dst.push_back(0.0f);
5799
5800 for (int sign = -1; sign <= 1; sign += 2)
5801 {
5802 // Smallest normalized
5803 dst.push_back((float)sign * minNormalized);
5804
5805 // Next smallest normalized
5806 dst.push_back((float)sign * (minNormalized + minQuantum));
5807
5808 dst.push_back((float)sign * 0.5f);
5809 dst.push_back((float)sign * 1.0f);
5810 dst.push_back((float)sign * 2.0f);
5811
5812 // Largest number
5813 dst.push_back((float)sign * (deFloatLdExp(1.0f, maxExp) +
5814 (deFloatLdExp(1.0f, maxExp) - maxQuantum)));
5815
5816 dst.push_back((float)sign * TCU_INFINITY);
5817 }
5818 removeNotInRange(dst, inputRange, prec);
5819 }
5820
removeNotInRange(vector<float> & dst,const Interval & inputRange,const Precision prec) const5821 void DefaultSampling<float>::removeNotInRange (vector<float>& dst, const Interval& inputRange, const Precision prec) const
5822 {
5823 for (vector<float>::iterator it = dst.begin(); it < dst.end();)
5824 {
5825 // Remove out of range values. PRECISION_LAST means this is an FP16 test so remove any values that
5826 // will be denorms when converted to FP16. (This is used in the precision_fp16_storage32b test group).
5827 if ( !inputRange.contains(static_cast<double>(*it)) || (prec == glu::PRECISION_LAST && isDenorm16(deFloat32To16Round(*it, DE_ROUNDINGMODE_TO_ZERO))))
5828 it = dst.erase(it);
5829 else
5830 ++it;
5831 }
5832 }
5833
5834 //! Generate a random double from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const5835 double DefaultSampling<double>::genRandom (const FloatFormat& format,
5836 Precision prec,
5837 Random& rnd,
5838 const Interval& inputRange) const
5839 {
5840 DE_UNREF(prec);
5841 return randomDouble(format, rnd, inputRange);
5842 }
5843
5844 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<double> & dst,const Interval & inputRange) const5845 void DefaultSampling<double>::genFixeds (const FloatFormat& format, const Precision prec, vector<double>& dst, const Interval& inputRange) const
5846 {
5847 const int minExp = format.getMinExp();
5848 const int maxExp = format.getMaxExp();
5849 const int fractionBits = format.getFractionBits();
5850 const double minQuantum = deLdExp(1.0, minExp - fractionBits);
5851 const double minNormalized = deLdExp(1.0, minExp);
5852 const double maxQuantum = deLdExp(1.0, maxExp - fractionBits);
5853
5854 // NaN
5855 dst.push_back(TCU_NAN);
5856 // Zero
5857 dst.push_back(0.0);
5858
5859 for (int sign = -1; sign <= 1; sign += 2)
5860 {
5861 // Smallest normalized
5862 dst.push_back((double)sign * minNormalized);
5863
5864 // Next smallest normalized
5865 dst.push_back((double)sign * (minNormalized + minQuantum));
5866
5867 dst.push_back((double)sign * 0.5);
5868 dst.push_back((double)sign * 1.0);
5869 dst.push_back((double)sign * 2.0);
5870
5871 // Largest number
5872 dst.push_back((double)sign * (deLdExp(1.0, maxExp) + (deLdExp(1.0, maxExp) - maxQuantum)));
5873
5874 dst.push_back((double)sign * TCU_INFINITY);
5875 }
5876 removeNotInRange(dst, inputRange, prec);
5877 }
5878
removeNotInRange(vector<double> & dst,const Interval & inputRange,const Precision) const5879 void DefaultSampling<double>::removeNotInRange (vector<double>& dst, const Interval& inputRange, const Precision) const
5880 {
5881 for (vector<double>::iterator it = dst.begin(); it < dst.end();)
5882 {
5883 if ( !inputRange.contains(*it) )
5884 it = dst.erase(it);
5885 else
5886 ++it;
5887 }
5888 }
5889
5890 template <>
5891 class DefaultSampling<deFloat16> : public Sampling<deFloat16>
5892 {
5893 public:
5894 deFloat16 genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5895 void genFixeds (const FloatFormat& format, const Precision prec, vector<deFloat16>& dst, const Interval& inputRange) const;
5896 private:
5897 void removeNotInRange(vector<deFloat16>& dst, const Interval& inputRange, const Precision prec) const;
5898 };
5899
5900 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,const Precision prec,Random & rnd,const Interval & inputRange) const5901 deFloat16 DefaultSampling<deFloat16>::genRandom (const FloatFormat& format, const Precision prec,
5902 Random& rnd, const Interval& inputRange) const
5903 {
5904 DE_UNREF(prec);
5905 return deFloat64To16Round(randomDouble(format, rnd, inputRange), DE_ROUNDINGMODE_TO_NEAREST_EVEN);
5906 }
5907
5908 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<deFloat16> & dst,const Interval & inputRange) const5909 void DefaultSampling<deFloat16>::genFixeds (const FloatFormat& format, const Precision prec, vector<deFloat16>& dst, const Interval& inputRange) const
5910 {
5911 dst.push_back(deUint16(0x3E00)); //1.5
5912 dst.push_back(deUint16(0x3D00)); //1.25
5913 dst.push_back(deUint16(0x3F00)); //1.75
5914 // Zero
5915 dst.push_back(deUint16(0x0000));
5916 dst.push_back(deUint16(0x8000));
5917 // Infinity
5918 dst.push_back(deUint16(0x7c00));
5919 dst.push_back(deUint16(0xfc00));
5920 // SNaN
5921 dst.push_back(deUint16(0x7c0f));
5922 dst.push_back(deUint16(0xfc0f));
5923 // QNaN
5924 dst.push_back(deUint16(0x7cf0));
5925 dst.push_back(deUint16(0xfcf0));
5926 // Normalized
5927 dst.push_back(deUint16(0x0401));
5928 dst.push_back(deUint16(0x8401));
5929 // Some normal number
5930 dst.push_back(deUint16(0x14cb));
5931 dst.push_back(deUint16(0x94cb));
5932
5933 const int minExp = format.getMinExp();
5934 const int maxExp = format.getMaxExp();
5935 const int fractionBits = format.getFractionBits();
5936 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
5937 const float minNormalized = deFloatLdExp(1.0f, minExp);
5938 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
5939
5940 for (float sign = -1.0; sign <= 1.0f; sign += 2.0f)
5941 {
5942 // Smallest normalized
5943 dst.push_back(deFloat32To16Round(sign * minNormalized, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5944
5945 // Next smallest normalized
5946 dst.push_back(deFloat32To16Round(sign * (minNormalized + minQuantum), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5947
5948 dst.push_back(deFloat32To16Round(sign * 0.5f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5949 dst.push_back(deFloat32To16Round(sign * 1.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5950 dst.push_back(deFloat32To16Round(sign * 2.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5951
5952 // Largest number
5953 dst.push_back(deFloat32To16Round(sign * (deFloatLdExp(1.0f, maxExp) +
5954 (deFloatLdExp(1.0f, maxExp) - maxQuantum)), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5955
5956 dst.push_back(deFloat32To16Round(sign * TCU_INFINITY, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5957 }
5958 removeNotInRange(dst, inputRange, prec);
5959 }
5960
removeNotInRange(vector<deFloat16> & dst,const Interval & inputRange,const Precision) const5961 void DefaultSampling<deFloat16>::removeNotInRange(vector<deFloat16>& dst, const Interval& inputRange, const Precision) const
5962 {
5963 for (vector<deFloat16>::iterator it = dst.begin(); it < dst.end();)
5964 {
5965 if (inputRange.contains(static_cast<double>(*it)))
5966 ++it;
5967 else
5968 it = dst.erase(it);
5969 }
5970 }
5971
5972 template <typename T, int Size>
5973 class DefaultSampling<Vector<T, Size> > : public Sampling<Vector<T, Size> >
5974 {
5975 public:
5976 typedef Vector<T, Size> Value;
5977
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const5978 Value genRandom (const FloatFormat& fmt, const Precision prec, Random& rnd, const Interval& inputRange) const
5979 {
5980 Value ret;
5981
5982 for (int ndx = 0; ndx < Size; ++ndx)
5983 ret[ndx] = instance<DefaultSampling<T> >().genRandom(fmt, prec, rnd, inputRange);
5984
5985 return ret;
5986 }
5987
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const5988 void genFixeds (const FloatFormat& fmt, const Precision prec, vector<Value>& dst, const Interval& inputRange) const
5989 {
5990 vector<T> scalars;
5991
5992 instance<DefaultSampling<T> >().genFixeds(fmt, prec, scalars, inputRange);
5993
5994 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
5995 dst.push_back(Value(scalars[scalarNdx]));
5996 }
5997 };
5998
5999 template <typename T, int Rows, int Columns>
6000 class DefaultSampling<Matrix<T, Rows, Columns> > : public Sampling<Matrix<T, Rows, Columns> >
6001 {
6002 public:
6003 typedef Matrix<T, Rows, Columns> Value;
6004
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const6005 Value genRandom (const FloatFormat& fmt, const Precision prec, Random& rnd, const Interval& inputRange) const
6006 {
6007 Value ret;
6008
6009 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
6010 for (int colNdx = 0; colNdx < Columns; ++colNdx)
6011 ret(rowNdx, colNdx) = instance<DefaultSampling<T> >().genRandom(fmt, prec, rnd, inputRange);
6012
6013 return ret;
6014 }
6015
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const6016 void genFixeds (const FloatFormat& fmt, const Precision prec, vector<Value>& dst, const Interval& inputRange) const
6017 {
6018 vector<T> scalars;
6019
6020 instance<DefaultSampling<T> >().genFixeds(fmt, prec, scalars, inputRange);
6021
6022 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
6023 dst.push_back(Value(scalars[scalarNdx]));
6024
6025 if (Columns == Rows)
6026 {
6027 Value mat (T(0.0));
6028 T x = T(1.0f);
6029 mat[0][0] = x;
6030 for (int ndx = 0; ndx < Columns; ++ndx)
6031 {
6032 mat[Columns-1-ndx][ndx] = x;
6033 x = static_cast<T>(x * static_cast<T>(2.0f));
6034 }
6035 dst.push_back(mat);
6036 }
6037 }
6038 };
6039
6040 struct CaseContext
6041 {
CaseContextvkt::shaderexecutor::CaseContext6042 CaseContext (const string& name_,
6043 TestContext& testContext_,
6044 const FloatFormat& floatFormat_,
6045 const FloatFormat& highpFormat_,
6046 const Precision precision_,
6047 const ShaderType shaderType_,
6048 const size_t numRandoms_,
6049 const PrecisionTestFeatures precisionTestFeatures_ = PRECISION_TEST_FEATURES_NONE,
6050 const bool isPackFloat16b_ = false,
6051 const bool isFloat64b_ = false)
6052 : name (name_)
6053 , testContext (testContext_)
6054 , floatFormat (floatFormat_)
6055 , highpFormat (highpFormat_)
6056 , precision (precision_)
6057 , shaderType (shaderType_)
6058 , numRandoms (numRandoms_)
6059 , inputRange (-TCU_INFINITY, TCU_INFINITY)
6060 , precisionTestFeatures (precisionTestFeatures_)
6061 , isPackFloat16b (isPackFloat16b_)
6062 , isFloat64b (isFloat64b_)
6063 {}
6064
6065 string name;
6066 TestContext& testContext;
6067 FloatFormat floatFormat;
6068 FloatFormat highpFormat;
6069 Precision precision;
6070 ShaderType shaderType;
6071 size_t numRandoms;
6072 Interval inputRange;
6073 PrecisionTestFeatures precisionTestFeatures;
6074 bool isPackFloat16b;
6075 bool isFloat64b;
6076 };
6077
6078 template<typename In0_ = Void, typename In1_ = Void, typename In2_ = Void, typename In3_ = Void>
6079 struct InTypes
6080 {
6081 typedef In0_ In0;
6082 typedef In1_ In1;
6083 typedef In2_ In2;
6084 typedef In3_ In3;
6085 };
6086
6087 template <typename In>
numInputs(void)6088 int numInputs (void)
6089 {
6090 return (!isTypeValid<typename In::In0>() ? 0 :
6091 !isTypeValid<typename In::In1>() ? 1 :
6092 !isTypeValid<typename In::In2>() ? 2 :
6093 !isTypeValid<typename In::In3>() ? 3 :
6094 4);
6095 }
6096
6097 template<typename Out0_, typename Out1_ = Void>
6098 struct OutTypes
6099 {
6100 typedef Out0_ Out0;
6101 typedef Out1_ Out1;
6102 };
6103
6104 template <typename Out>
numOutputs(void)6105 int numOutputs (void)
6106 {
6107 return (!isTypeValid<typename Out::Out0>() ? 0 :
6108 !isTypeValid<typename Out::Out1>() ? 1 :
6109 2);
6110 }
6111
6112 template<typename In>
6113 struct Inputs
6114 {
6115 vector<typename In::In0> in0;
6116 vector<typename In::In1> in1;
6117 vector<typename In::In2> in2;
6118 vector<typename In::In3> in3;
6119 };
6120
6121 template<typename Out>
6122 struct Outputs
6123 {
Outputsvkt::shaderexecutor::Outputs6124 Outputs (size_t size) : out0(size), out1(size) {}
6125
6126 vector<typename Out::Out0> out0;
6127 vector<typename Out::Out1> out1;
6128 };
6129
6130 template<typename In, typename Out>
6131 struct Variables
6132 {
6133 VariableP<typename In::In0> in0;
6134 VariableP<typename In::In1> in1;
6135 VariableP<typename In::In2> in2;
6136 VariableP<typename In::In3> in3;
6137 VariableP<typename Out::Out0> out0;
6138 VariableP<typename Out::Out1> out1;
6139 };
6140
6141 template<typename In>
6142 struct Samplings
6143 {
Samplingsvkt::shaderexecutor::Samplings6144 Samplings (const Sampling<typename In::In0>& in0_,
6145 const Sampling<typename In::In1>& in1_,
6146 const Sampling<typename In::In2>& in2_,
6147 const Sampling<typename In::In3>& in3_)
6148 : in0 (in0_), in1 (in1_), in2 (in2_), in3 (in3_) {}
6149
6150 const Sampling<typename In::In0>& in0;
6151 const Sampling<typename In::In1>& in1;
6152 const Sampling<typename In::In2>& in2;
6153 const Sampling<typename In::In3>& in3;
6154 };
6155
6156 template<typename In>
6157 struct DefaultSamplings : Samplings<In>
6158 {
DefaultSamplingsvkt::shaderexecutor::DefaultSamplings6159 DefaultSamplings (void)
6160 : Samplings<In>(instance<DefaultSampling<typename In::In0> >(),
6161 instance<DefaultSampling<typename In::In1> >(),
6162 instance<DefaultSampling<typename In::In2> >(),
6163 instance<DefaultSampling<typename In::In3> >()) {}
6164 };
6165
6166 template <typename In, typename Out>
6167 class BuiltinPrecisionCaseTestInstance : public TestInstance
6168 {
6169 public:
BuiltinPrecisionCaseTestInstance(Context & context,const CaseContext caseCtx,const ShaderSpec & shaderSpec,const Variables<In,Out> variables,const Samplings<In> & samplings,const StatementP stmt,bool modularOp=false)6170 BuiltinPrecisionCaseTestInstance (Context& context,
6171 const CaseContext caseCtx,
6172 const ShaderSpec& shaderSpec,
6173 const Variables<In, Out> variables,
6174 const Samplings<In>& samplings,
6175 const StatementP stmt,
6176 bool modularOp = false)
6177 : TestInstance (context)
6178 , m_caseCtx (caseCtx)
6179 , m_variables (variables)
6180 , m_samplings (samplings)
6181 , m_stmt (stmt)
6182 , m_executor (createExecutor(context, caseCtx.shaderType, shaderSpec))
6183 , m_modularOp (modularOp)
6184 {
6185 }
6186 virtual tcu::TestStatus iterate (void);
6187
6188 protected:
6189 CaseContext m_caseCtx;
6190 Variables<In, Out> m_variables;
6191 const Samplings<In>& m_samplings;
6192 StatementP m_stmt;
6193 de::UniquePtr<ShaderExecutor> m_executor;
6194 bool m_modularOp;
6195 };
6196
6197 template<class In, class Out>
iterate(void)6198 tcu::TestStatus BuiltinPrecisionCaseTestInstance<In, Out>::iterate (void)
6199 {
6200 typedef typename In::In0 In0;
6201 typedef typename In::In1 In1;
6202 typedef typename In::In2 In2;
6203 typedef typename In::In3 In3;
6204 typedef typename Out::Out0 Out0;
6205 typedef typename Out::Out1 Out1;
6206
6207 areFeaturesSupported(m_context, m_caseCtx.precisionTestFeatures);
6208 Inputs<In> inputs = generateInputs(m_samplings, m_caseCtx.floatFormat, m_caseCtx.precision, m_caseCtx.numRandoms, 0xdeadbeefu + m_caseCtx.testContext.getCommandLine().getBaseSeed(), m_caseCtx.inputRange);
6209 const FloatFormat& fmt = m_caseCtx.floatFormat;
6210 const int inCount = numInputs<In>();
6211 const int outCount = numOutputs<Out>();
6212 const size_t numValues = (inCount > 0) ? inputs.in0.size() : 1;
6213 Outputs<Out> outputs (numValues);
6214 const FloatFormat highpFmt = m_caseCtx.highpFormat;
6215 const int maxMsgs = 100;
6216 int numErrors = 0;
6217 Environment env; // Hoisted out of the inner loop for optimization.
6218 ResultCollector status;
6219 TestLog& testLog = m_context.getTestContext().getLog();
6220
6221 // Module operations need exactly two inputs and have exactly one output.
6222 if (m_modularOp)
6223 {
6224 DE_ASSERT(inCount == 2);
6225 DE_ASSERT(outCount == 1);
6226 }
6227
6228 const void* inputArr[] =
6229 {
6230 inputs.in0.data(), inputs.in1.data(), inputs.in2.data(), inputs.in3.data(),
6231 };
6232 void* outputArr[] =
6233 {
6234 outputs.out0.data(), outputs.out1.data(),
6235 };
6236
6237 // Print out the statement and its definitions
6238 testLog << TestLog::Message << "Statement: " << m_stmt << TestLog::EndMessage;
6239 {
6240 ostringstream oss;
6241 FuncSet funcs;
6242
6243 m_stmt->getUsedFuncs(funcs);
6244 for (FuncSet::const_iterator it = funcs.begin(); it != funcs.end(); ++it)
6245 {
6246 (*it)->printDefinition(oss);
6247 }
6248 if (!funcs.empty())
6249 testLog << TestLog::Message << "Reference definitions:\n" << oss.str()
6250 << TestLog::EndMessage;
6251 }
6252 switch (inCount)
6253 {
6254 case 4:
6255 DE_ASSERT(inputs.in3.size() == numValues);
6256 // Fallthrough
6257 case 3:
6258 DE_ASSERT(inputs.in2.size() == numValues);
6259 // Fallthrough
6260 case 2:
6261 DE_ASSERT(inputs.in1.size() == numValues);
6262 // Fallthrough
6263 case 1:
6264 DE_ASSERT(inputs.in0.size() == numValues);
6265 // Fallthrough
6266 default:
6267 break;
6268 }
6269
6270 m_executor->execute(int(numValues), inputArr, outputArr);
6271
6272 // Initialize environment with unused values so we don't need to bind in inner loop.
6273 {
6274 const typename Traits<In0>::IVal in0;
6275 const typename Traits<In1>::IVal in1;
6276 const typename Traits<In2>::IVal in2;
6277 const typename Traits<In3>::IVal in3;
6278 const typename Traits<Out0>::IVal reference0;
6279 const typename Traits<Out1>::IVal reference1;
6280
6281 env.bind(*m_variables.in0, in0);
6282 env.bind(*m_variables.in1, in1);
6283 env.bind(*m_variables.in2, in2);
6284 env.bind(*m_variables.in3, in3);
6285 env.bind(*m_variables.out0, reference0);
6286 env.bind(*m_variables.out1, reference1);
6287 }
6288
6289 // For each input tuple, compute output reference interval and compare
6290 // shader output to the reference.
6291 for (size_t valueNdx = 0; valueNdx < numValues; valueNdx++)
6292 {
6293 bool result = true;
6294 const bool isInput16Bit = m_executor->areInputs16Bit();
6295 const bool isInput64Bit = m_executor->areInputs64Bit();
6296
6297 DE_ASSERT(!(isInput16Bit && isInput64Bit));
6298
6299 typename Traits<Out0>::IVal reference0;
6300 typename Traits<Out1>::IVal reference1;
6301
6302 if (valueNdx % (size_t)TOUCH_WATCHDOG_VALUE_FREQUENCY == 0)
6303 m_context.getTestContext().touchWatchdog();
6304
6305 env.lookup(*m_variables.in0) = convert<In0>(fmt, round(fmt, inputs.in0[valueNdx]));
6306 env.lookup(*m_variables.in1) = convert<In1>(fmt, round(fmt, inputs.in1[valueNdx]));
6307 env.lookup(*m_variables.in2) = convert<In2>(fmt, round(fmt, inputs.in2[valueNdx]));
6308 env.lookup(*m_variables.in3) = convert<In3>(fmt, round(fmt, inputs.in3[valueNdx]));
6309
6310 {
6311 EvalContext ctx (fmt, m_caseCtx.precision, env, 0);
6312 m_stmt->execute(ctx);
6313
6314 switch (outCount)
6315 {
6316 case 2:
6317 reference1 = convert<Out1>(highpFmt, env.lookup(*m_variables.out1));
6318 if (!status.check(contains(reference1, outputs.out1[valueNdx], m_caseCtx.isPackFloat16b), "Shader output 1 is outside acceptable range"))
6319 result = false;
6320 // Fallthrough
6321 case 1:
6322 {
6323 // Pass b from mod(a, b) if we are in the modulo operation.
6324 const tcu::Maybe<In1> modularDivisor = (m_modularOp ? tcu::just(inputs.in1[valueNdx]) : tcu::Nothing);
6325
6326 reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
6327 if (!status.check(contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor), "Shader output 0 is outside acceptable range"))
6328 {
6329 m_stmt->failed(ctx);
6330 reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
6331 if (!status.check(contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor), "Shader output 0 is outside acceptable range"))
6332 result = false;
6333 }
6334 }
6335 // Fallthrough
6336 default: break;
6337 }
6338
6339 }
6340 if (!result)
6341 ++numErrors;
6342
6343 if ((!result && numErrors <= maxMsgs) || GLS_LOG_ALL_RESULTS)
6344 {
6345 MessageBuilder builder = testLog.message();
6346
6347 builder << (result ? "Passed" : "Failed") << " sample:\n";
6348
6349 if (inCount > 0)
6350 {
6351 builder << "\t" << m_variables.in0->getName() << " = "
6352 << (isInput64Bit ? value64ToString(highpFmt, inputs.in0[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in0[valueNdx]) : value32ToString(highpFmt, inputs.in0[valueNdx]))) << "\n";
6353 }
6354
6355 if (inCount > 1)
6356 {
6357 builder << "\t" << m_variables.in1->getName() << " = "
6358 << (isInput64Bit ? value64ToString(highpFmt, inputs.in1[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in1[valueNdx]) : value32ToString(highpFmt, inputs.in1[valueNdx]))) << "\n";
6359 }
6360
6361 if (inCount > 2)
6362 {
6363 builder << "\t" << m_variables.in2->getName() << " = "
6364 << (isInput64Bit ? value64ToString(highpFmt, inputs.in2[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in2[valueNdx]) : value32ToString(highpFmt, inputs.in2[valueNdx]))) << "\n";
6365 }
6366
6367 if (inCount > 3)
6368 {
6369 builder << "\t" << m_variables.in3->getName() << " = "
6370 << (isInput64Bit ? value64ToString(highpFmt, inputs.in3[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in3[valueNdx]) : value32ToString(highpFmt, inputs.in3[valueNdx]))) << "\n";
6371 }
6372
6373 if (outCount > 0)
6374 {
6375 if (m_executor->spirvCase() == SPIRV_CASETYPE_COMPARE)
6376 {
6377 builder << "Output:\n"
6378 << comparisonMessage(outputs.out0[valueNdx])
6379 << "Expected result:\n"
6380 << comparisonMessageInterval<typename Out::Out0>(reference0) << "\n";
6381 }
6382 else
6383 {
6384 builder << "\t" << m_variables.out0->getName() << " = "
6385 << (m_executor->isOutput64Bit(0u) ? value64ToString(highpFmt, outputs.out0[valueNdx]) : (m_executor->isOutput16Bit(0u) || m_caseCtx.isPackFloat16b ? value16ToString(highpFmt, outputs.out0[valueNdx]) : value32ToString(highpFmt, outputs.out0[valueNdx]))) << "\n"
6386 << "\tExpected range: "
6387 << intervalToString<typename Out::Out0>(highpFmt, reference0) << "\n";
6388 }
6389 }
6390
6391 if (outCount > 1)
6392 {
6393 builder << "\t" << m_variables.out1->getName() << " = "
6394 << (m_executor->isOutput64Bit(1u) ? value64ToString(highpFmt, outputs.out1[valueNdx]) : (m_executor->isOutput16Bit(1u) || m_caseCtx.isPackFloat16b ? value16ToString(highpFmt, outputs.out1[valueNdx]) : value32ToString(highpFmt, outputs.out1[valueNdx]))) << "\n"
6395 << "\tExpected range: "
6396 << intervalToString<typename Out::Out1>(highpFmt, reference1) << "\n";
6397 }
6398
6399 builder << TestLog::EndMessage;
6400 }
6401 }
6402
6403 if (numErrors > maxMsgs)
6404 {
6405 testLog << TestLog::Message << "(Skipped " << (numErrors - maxMsgs) << " messages.)"
6406 << TestLog::EndMessage;
6407 }
6408
6409 if (numErrors == 0)
6410 {
6411 testLog << TestLog::Message << "All " << numValues << " inputs passed."
6412 << TestLog::EndMessage;
6413 }
6414 else
6415 {
6416 testLog << TestLog::Message << numErrors << "/" << numValues << " inputs failed."
6417 << TestLog::EndMessage;
6418 }
6419
6420 if (numErrors)
6421 return tcu::TestStatus::fail(de::toString(numErrors) + string(" test failed. Check log for the details"));
6422 else
6423 return tcu::TestStatus::pass("Pass");
6424
6425 }
6426
6427 class PrecisionCase : public TestCase
6428 {
6429 protected:
PrecisionCase(const CaseContext & context,const string & name,const Interval & inputRange,const string & extension="")6430 PrecisionCase (const CaseContext& context, const string& name, const Interval& inputRange, const string& extension = "")
6431 : TestCase (context.testContext, name.c_str(), name.c_str())
6432 , m_ctx (context)
6433 , m_extension (extension)
6434 {
6435 m_ctx.inputRange = inputRange;
6436 m_spec.packFloat16Bit = context.isPackFloat16b;
6437 }
6438
initPrograms(vk::SourceCollections & programCollection) const6439 virtual void initPrograms (vk::SourceCollections& programCollection) const
6440 {
6441 generateSources(m_ctx.shaderType, m_spec, programCollection);
6442 }
6443
getFormat(void) const6444 const FloatFormat& getFormat (void) const { return m_ctx.floatFormat; }
6445
6446 template <typename In, typename Out>
6447 void testStatement (const Variables<In, Out>& variables, const Statement& stmt, SpirVCaseT spirvCase);
6448
6449 template<typename T>
makeSymbol(const Variable<T> & variable)6450 Symbol makeSymbol (const Variable<T>& variable)
6451 {
6452 return Symbol(variable.getName(), getVarTypeOf<T>(m_ctx.precision));
6453 }
6454
6455 CaseContext m_ctx;
6456 const string m_extension;
6457 ShaderSpec m_spec;
6458 };
6459
6460 template <typename In, typename Out>
testStatement(const Variables<In,Out> & variables,const Statement & stmt,SpirVCaseT spirvCase)6461 void PrecisionCase::testStatement (const Variables<In, Out>& variables, const Statement& stmt, SpirVCaseT spirvCase)
6462 {
6463 const int inCount = numInputs<In>();
6464 const int outCount = numOutputs<Out>();
6465 Environment env; // Hoisted out of the inner loop for optimization.
6466
6467 // Initialize ShaderSpec from precision, variables and statement.
6468 if (m_ctx.precision != glu::PRECISION_LAST)
6469 {
6470 ostringstream os;
6471 os << "precision " << glu::getPrecisionName(m_ctx.precision) << " float;\n";
6472 m_spec.globalDeclarations = os.str();
6473 }
6474
6475 if (!m_extension.empty())
6476 m_spec.globalDeclarations = "#extension " + m_extension + " : require\n";
6477
6478 m_spec.inputs.resize(inCount);
6479
6480 switch (inCount)
6481 {
6482 case 4:
6483 m_spec.inputs[3] = makeSymbol(*variables.in3);
6484 // Fallthrough
6485 case 3:
6486 m_spec.inputs[2] = makeSymbol(*variables.in2);
6487 // Fallthrough
6488 case 2:
6489 m_spec.inputs[1] = makeSymbol(*variables.in1);
6490 // Fallthrough
6491 case 1:
6492 m_spec.inputs[0] = makeSymbol(*variables.in0);
6493 // Fallthrough
6494 default:
6495 break;
6496 }
6497
6498 bool inputs16Bit = false;
6499 for (vector<Symbol>::const_iterator symIter = m_spec.inputs.begin(); symIter != m_spec.inputs.end(); ++symIter)
6500 inputs16Bit = inputs16Bit || glu::isDataTypeFloat16OrVec(symIter->varType.getBasicType());
6501
6502 if (inputs16Bit || m_spec.packFloat16Bit)
6503 m_spec.globalDeclarations += "#extension GL_EXT_shader_explicit_arithmetic_types: require\n";
6504
6505 m_spec.outputs.resize(outCount);
6506
6507 switch (outCount)
6508 {
6509 case 2:
6510 m_spec.outputs[1] = makeSymbol(*variables.out1);
6511 // Fallthrough
6512 case 1:
6513 m_spec.outputs[0] = makeSymbol(*variables.out0);
6514 // Fallthrough
6515 default:
6516 break;
6517 }
6518
6519 m_spec.source = de::toString(stmt);
6520 m_spec.spirvCase = spirvCase;
6521 }
6522
//! Ordering functor for input values. The primary template defers to the
//! type's own operator<.
template <typename T>
struct InputLess
{
	bool operator() (const T& val1, const T& val2) const
	{
		const bool isLess = (val1 < val2);
		return isLess;
	}
};
6531
6532 template <typename T>
inputLess(const T & val1,const T & val2)6533 bool inputLess (const T& val1, const T& val2)
6534 {
6535 return InputLess<T>()(val1, val2);
6536 }
6537
6538 template <>
6539 struct InputLess<float>
6540 {
operator ()vkt::shaderexecutor::InputLess6541 bool operator() (const float& val1, const float& val2) const
6542 {
6543 if (deIsNaN(val1))
6544 return false;
6545 if (deIsNaN(val2))
6546 return true;
6547 return val1 < val2;
6548 }
6549 };
6550
6551 template <typename T, int Size>
6552 struct InputLess<Vector<T, Size> >
6553 {
operator ()vkt::shaderexecutor::InputLess6554 bool operator() (const Vector<T, Size>& vec1, const Vector<T, Size>& vec2) const
6555 {
6556 for (int ndx = 0; ndx < Size; ++ndx)
6557 {
6558 if (inputLess(vec1[ndx], vec2[ndx]))
6559 return true;
6560 if (inputLess(vec2[ndx], vec1[ndx]))
6561 return false;
6562 }
6563
6564 return false;
6565 }
6566 };
6567
6568 template <typename T, int Rows, int Cols>
6569 struct InputLess<Matrix<T, Rows, Cols> >
6570 {
operator ()vkt::shaderexecutor::InputLess6571 bool operator() (const Matrix<T, Rows, Cols>& mat1,
6572 const Matrix<T, Rows, Cols>& mat2) const
6573 {
6574 for (int col = 0; col < Cols; ++col)
6575 {
6576 if (inputLess(mat1[col], mat2[col]))
6577 return true;
6578 if (inputLess(mat2[col], mat1[col]))
6579 return false;
6580 }
6581
6582 return false;
6583 }
6584 };
6585
6586 template <typename In>
6587 struct InTuple :
6588 public Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6589 {
InTuplevkt::shaderexecutor::InTuple6590 InTuple (const typename In::In0& in0,
6591 const typename In::In1& in1,
6592 const typename In::In2& in2,
6593 const typename In::In3& in3)
6594 : Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6595 (in0, in1, in2, in3) {}
6596 };
6597
6598 template <typename In>
6599 struct InputLess<InTuple<In> >
6600 {
operator ()vkt::shaderexecutor::InputLess6601 bool operator() (const InTuple<In>& in1, const InTuple<In>& in2) const
6602 {
6603 if (inputLess(in1.a, in2.a))
6604 return true;
6605 if (inputLess(in2.a, in1.a))
6606 return false;
6607 if (inputLess(in1.b, in2.b))
6608 return true;
6609 if (inputLess(in2.b, in1.b))
6610 return false;
6611 if (inputLess(in1.c, in2.c))
6612 return true;
6613 if (inputLess(in2.c, in1.c))
6614 return false;
6615 if (inputLess(in1.d, in2.d))
6616 return true;
6617 return false;
6618 }
6619 };
6620
6621 template<typename In>
generateInputs(const Samplings<In> & samplings,const FloatFormat & floatFormat,Precision intPrecision,size_t numSamples,deUint32 seed,const Interval & inputRange)6622 Inputs<In> generateInputs (const Samplings<In>& samplings,
6623 const FloatFormat& floatFormat,
6624 Precision intPrecision,
6625 size_t numSamples,
6626 deUint32 seed,
6627 const Interval& inputRange)
6628 {
6629 Random rnd(seed);
6630 Inputs<In> ret;
6631 Inputs<In> fixedInputs;
6632 set<InTuple<In>, InputLess<InTuple<In> > > seenInputs;
6633
6634 samplings.in0.genFixeds(floatFormat, intPrecision, fixedInputs.in0, inputRange);
6635 samplings.in1.genFixeds(floatFormat, intPrecision, fixedInputs.in1, inputRange);
6636 samplings.in2.genFixeds(floatFormat, intPrecision, fixedInputs.in2, inputRange);
6637 samplings.in3.genFixeds(floatFormat, intPrecision, fixedInputs.in3, inputRange);
6638
6639 for (size_t ndx0 = 0; ndx0 < fixedInputs.in0.size(); ++ndx0)
6640 {
6641 for (size_t ndx1 = 0; ndx1 < fixedInputs.in1.size(); ++ndx1)
6642 {
6643 for (size_t ndx2 = 0; ndx2 < fixedInputs.in2.size(); ++ndx2)
6644 {
6645 for (size_t ndx3 = 0; ndx3 < fixedInputs.in3.size(); ++ndx3)
6646 {
6647 const InTuple<In> tuple (fixedInputs.in0[ndx0],
6648 fixedInputs.in1[ndx1],
6649 fixedInputs.in2[ndx2],
6650 fixedInputs.in3[ndx3]);
6651
6652 seenInputs.insert(tuple);
6653 ret.in0.push_back(tuple.a);
6654 ret.in1.push_back(tuple.b);
6655 ret.in2.push_back(tuple.c);
6656 ret.in3.push_back(tuple.d);
6657 }
6658 }
6659 }
6660 }
6661
6662 for (size_t ndx = 0; ndx < numSamples; ++ndx)
6663 {
6664 const typename In::In0 in0 = samplings.in0.genRandom(floatFormat, intPrecision, rnd, inputRange);
6665 const typename In::In1 in1 = samplings.in1.genRandom(floatFormat, intPrecision, rnd, inputRange);
6666 const typename In::In2 in2 = samplings.in2.genRandom(floatFormat, intPrecision, rnd, inputRange);
6667 const typename In::In3 in3 = samplings.in3.genRandom(floatFormat, intPrecision, rnd, inputRange);
6668 const InTuple<In> tuple (in0, in1, in2, in3);
6669
6670 if (de::contains(seenInputs, tuple))
6671 continue;
6672
6673 seenInputs.insert(tuple);
6674 ret.in0.push_back(in0);
6675 ret.in1.push_back(in1);
6676 ret.in2.push_back(in2);
6677 ret.in3.push_back(in3);
6678 }
6679
6680 return ret;
6681 }
6682
6683 class FuncCaseBase : public PrecisionCase
6684 {
6685 protected:
FuncCaseBase(const CaseContext & context,const string & name,const FuncBase & func)6686 FuncCaseBase (const CaseContext& context, const string& name, const FuncBase& func)
6687 : PrecisionCase (context, name, func.getInputRange(!context.isFloat64b && (context.precision == glu::PRECISION_LAST || context.isPackFloat16b)), func.getRequiredExtension())
6688 {
6689 }
6690
6691 StatementP m_stmt;
6692 };
6693
6694 template <typename Sig>
6695 class FuncCase : public FuncCaseBase
6696 {
6697 public:
6698 typedef Func<Sig> CaseFunc;
6699 typedef typename Sig::Ret Ret;
6700 typedef typename Sig::Arg0 Arg0;
6701 typedef typename Sig::Arg1 Arg1;
6702 typedef typename Sig::Arg2 Arg2;
6703 typedef typename Sig::Arg3 Arg3;
6704 typedef InTypes<Arg0, Arg1, Arg2, Arg3> In;
6705 typedef OutTypes<Ret> Out;
6706
FuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)6707 FuncCase (const CaseContext& context, const string& name, const CaseFunc& func, bool modularOp = false)
6708 : FuncCaseBase (context, name, func)
6709 , m_func (func)
6710 , m_modularOp (modularOp)
6711 {
6712 buildTest();
6713 }
6714
createInstance(Context & context) const6715 virtual TestInstance* createInstance (Context& context) const
6716 {
6717 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(), m_stmt, m_modularOp);
6718 }
6719
6720 protected:
6721 void buildTest (void);
getSamplings(void) const6722 virtual const Samplings<In>& getSamplings (void) const
6723 {
6724 return instance<DefaultSamplings<In> >();
6725 }
6726
6727 private:
6728 const CaseFunc& m_func;
6729 Variables<In, Out> m_variables;
6730 bool m_modularOp;
6731 };
6732
6733 template <typename Sig>
buildTest(void)6734 void FuncCase<Sig>::buildTest (void)
6735 {
6736 m_variables.out0 = variable<Ret>("out0");
6737 m_variables.out1 = variable<Void>("out1");
6738 m_variables.in0 = variable<Arg0>("in0");
6739 m_variables.in1 = variable<Arg1>("in1");
6740 m_variables.in2 = variable<Arg2>("in2");
6741 m_variables.in3 = variable<Arg3>("in3");
6742
6743 {
6744 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.in1, m_variables.in2, m_variables.in3);
6745 m_stmt = variableAssignment(m_variables.out0, expr);
6746
6747 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
6748 }
6749 }
6750
6751 template <typename Sig>
6752 class InOutFuncCase : public FuncCaseBase
6753 {
6754 public:
6755 typedef Func<Sig> CaseFunc;
6756 typedef typename Sig::Ret Ret;
6757 typedef typename Sig::Arg0 Arg0;
6758 typedef typename Sig::Arg1 Arg1;
6759 typedef typename Sig::Arg2 Arg2;
6760 typedef typename Sig::Arg3 Arg3;
6761 typedef InTypes<Arg0, Arg2, Arg3> In;
6762 typedef OutTypes<Ret, Arg1> Out;
6763
InOutFuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)6764 InOutFuncCase (const CaseContext& context, const string& name, const CaseFunc& func, bool modularOp = false)
6765 : FuncCaseBase (context, name, func)
6766 , m_func (func)
6767 , m_modularOp (modularOp)
6768 {
6769 buildTest();
6770 }
createInstance(Context & context) const6771 virtual TestInstance* createInstance (Context& context) const
6772 {
6773 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(), m_stmt, m_modularOp);
6774 }
6775
6776 protected:
6777 void buildTest (void);
getSamplings(void) const6778 virtual const Samplings<In>& getSamplings (void) const
6779 {
6780 return instance<DefaultSamplings<In> >();
6781 }
6782
6783 private:
6784 const CaseFunc& m_func;
6785 Variables<In, Out> m_variables;
6786 bool m_modularOp;
6787 };
6788
6789 template <typename Sig>
buildTest(void)6790 void InOutFuncCase<Sig>::buildTest (void)
6791 {
6792 m_variables.out0 = variable<Ret>("out0");
6793 m_variables.out1 = variable<Arg1>("out1");
6794 m_variables.in0 = variable<Arg0>("in0");
6795 m_variables.in1 = variable<Arg2>("in1");
6796 m_variables.in2 = variable<Arg3>("in2");
6797 m_variables.in3 = variable<Void>("in3");
6798
6799 {
6800 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.out1, m_variables.in1, m_variables.in2);
6801 m_stmt = variableAssignment(m_variables.out0, expr);
6802
6803 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
6804 }
6805 }
6806
6807 template <typename Sig>
createFuncCase(const CaseContext & context,const string & name,const Func<Sig> & func,bool modularOp=false)6808 PrecisionCase* createFuncCase (const CaseContext& context, const string& name, const Func<Sig>& func, bool modularOp = false)
6809 {
6810 switch (func.getOutParamIndex())
6811 {
6812 case -1:
6813 return new FuncCase<Sig>(context, name, func, modularOp);
6814 case 1:
6815 return new InOutFuncCase<Sig>(context, name, func, modularOp);
6816 default:
6817 DE_FATAL("Impossible");
6818 }
6819 return DE_NULL;
6820 }
6821
6822 class CaseFactory
6823 {
6824 public:
~CaseFactory(void)6825 virtual ~CaseFactory (void) {}
6826 virtual MovePtr<TestNode> createCase (const CaseContext& ctx) const = 0;
6827 virtual string getName (void) const = 0;
6828 virtual string getDesc (void) const = 0;
6829 };
6830
6831 class FuncCaseFactory : public CaseFactory
6832 {
6833 public:
6834 virtual const FuncBase& getFunc (void) const = 0;
getName(void) const6835 string getName (void) const { return de::toLower(getFunc().getName()); }
getDesc(void) const6836 string getDesc (void) const { return "Function '" + getFunc().getName() + "'"; }
6837 };
6838
6839 template <typename Sig>
6840 class GenFuncCaseFactory : public CaseFactory
6841 {
6842 public:
GenFuncCaseFactory(const GenFuncs<Sig> & funcs,const string & name,bool modularOp=false)6843 GenFuncCaseFactory (const GenFuncs<Sig>& funcs, const string& name, bool modularOp = false)
6844 : m_funcs (funcs)
6845 , m_name (de::toLower(name))
6846 , m_modularOp (modularOp)
6847 {
6848 }
6849
createCase(const CaseContext & ctx) const6850 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6851 {
6852 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6853
6854 group->addChild(createFuncCase(ctx, "scalar", m_funcs.func, m_modularOp));
6855 group->addChild(createFuncCase(ctx, "vec2", m_funcs.func2, m_modularOp));
6856 group->addChild(createFuncCase(ctx, "vec3", m_funcs.func3, m_modularOp));
6857 group->addChild(createFuncCase(ctx, "vec4", m_funcs.func4, m_modularOp));
6858 return MovePtr<TestNode>(group);
6859 }
6860
getName(void) const6861 string getName (void) const { return m_name; }
getDesc(void) const6862 string getDesc (void) const { return "Function '" + m_funcs.func.getName() + "'"; }
6863
6864 private:
6865 const GenFuncs<Sig> m_funcs;
6866 string m_name;
6867 bool m_modularOp;
6868 };
6869
6870 template <template <int, class> class GenF, typename T>
6871 class TemplateFuncCaseFactory : public FuncCaseFactory
6872 {
6873 public:
createCase(const CaseContext & ctx) const6874 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6875 {
6876 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6877
6878 group->addChild(createFuncCase(ctx, "scalar", instance<GenF<1, T> >()));
6879 group->addChild(createFuncCase(ctx, "vec2", instance<GenF<2, T> >()));
6880 group->addChild(createFuncCase(ctx, "vec3", instance<GenF<3, T> >()));
6881 group->addChild(createFuncCase(ctx, "vec4", instance<GenF<4, T> >()));
6882
6883 return MovePtr<TestNode>(group);
6884 }
6885
getFunc(void) const6886 const FuncBase& getFunc (void) const { return instance<GenF<1, T> >(); }
6887 };
6888
6889 #ifndef CTS_USES_VULKANSC
6890 template <template <int> class GenF>
6891 class SquareMatrixFuncCaseFactory : public FuncCaseFactory
6892 {
6893 public:
createCase(const CaseContext & ctx) const6894 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6895 {
6896 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6897
6898 group->addChild(createFuncCase(ctx, "mat2", instance<GenF<2> >()));
6899
6900 // There is no defined precision for mediump/RelaxedPrecision in Vulkan
6901 if (ctx.name != "mediump")
6902 {
6903 static const char dataDir[] = "builtin/precision/square_matrix";
6904 std::string fileName = getFunc().getName() + "_" + ctx.name;
6905 std::vector<std::string> requirements;
6906
6907 if (ctx.name == "compute")
6908 {
6909 if (ctx.isFloat64b)
6910 {
6911 requirements.push_back("Features.shaderFloat64");
6912 fileName += "_fp64";
6913 }
6914 else
6915 {
6916 requirements.push_back("Float16Int8Features.shaderFloat16");
6917 requirements.push_back("VK_KHR_16bit_storage");
6918 requirements.push_back("VK_KHR_storage_buffer_storage_class");
6919 fileName += "_fp16";
6920
6921 if (ctx.isPackFloat16b == true)
6922 {
6923 fileName += "_32bit";
6924 }
6925 else
6926 {
6927 requirements.push_back("Storage16BitFeatures.storageBuffer16BitAccess");
6928 }
6929 }
6930 }
6931
6932 group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat3", "Square matrix 3x3 precision tests", dataDir, fileName + "_mat_3x3.amber", requirements));
6933 group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat4", "Square matrix 4x4 precision tests", dataDir, fileName + "_mat_4x4.amber", requirements));
6934 }
6935
6936 return MovePtr<TestNode>(group);
6937 }
6938
getFunc(void) const6939 const FuncBase& getFunc (void) const { return instance<GenF<2> >(); }
6940 };
6941 #endif // CTS_USES_VULKANSC
6942
6943 template <template <int, int, class> class GenF, typename T>
6944 class MatrixFuncCaseFactory : public FuncCaseFactory
6945 {
6946 public:
createCase(const CaseContext & ctx) const6947 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6948 {
6949 TestCaseGroup* const group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6950
6951 this->addCase<2, 2>(ctx, group);
6952 this->addCase<3, 2>(ctx, group);
6953 this->addCase<4, 2>(ctx, group);
6954 this->addCase<2, 3>(ctx, group);
6955 this->addCase<3, 3>(ctx, group);
6956 this->addCase<4, 3>(ctx, group);
6957 this->addCase<2, 4>(ctx, group);
6958 this->addCase<3, 4>(ctx, group);
6959 this->addCase<4, 4>(ctx, group);
6960
6961 return MovePtr<TestNode>(group);
6962 }
6963
getFunc(void) const6964 const FuncBase& getFunc (void) const { return instance<GenF<2,2, T> >(); }
6965
6966 private:
6967 template <int Rows, int Cols>
addCase(const CaseContext & ctx,TestCaseGroup * group) const6968 void addCase (const CaseContext& ctx, TestCaseGroup* group) const
6969 {
6970 const char* const name = dataTypeNameOf<Matrix<float, Rows, Cols> >();
6971 group->addChild(createFuncCase(ctx, name, instance<GenF<Rows, Cols, T> >()));
6972 }
6973 };
6974
6975 template <typename Sig>
6976 class SimpleFuncCaseFactory : public CaseFactory
6977 {
6978 public:
SimpleFuncCaseFactory(const Func<Sig> & func)6979 SimpleFuncCaseFactory (const Func<Sig>& func) : m_func(func) {}
6980
createCase(const CaseContext & ctx) const6981 MovePtr<TestNode> createCase (const CaseContext& ctx) const { return MovePtr<TestNode>(createFuncCase(ctx, ctx.name.c_str(), m_func)); }
getName(void) const6982 string getName (void) const { return de::toLower(m_func.getName()); }
getDesc(void) const6983 string getDesc (void) const { return "Function '" + getName() + "'"; }
6984
6985 private:
6986 const Func<Sig>& m_func;
6987 };
6988
6989 template <typename F>
createSimpleFuncCaseFactory(void)6990 SharedPtr<SimpleFuncCaseFactory<typename F::Sig> > createSimpleFuncCaseFactory (void)
6991 {
6992 return SharedPtr<SimpleFuncCaseFactory<typename F::Sig> >(new SimpleFuncCaseFactory<typename F::Sig>(instance<F>()));
6993 }
6994
6995 class CaseFactories
6996 {
6997 public:
~CaseFactories(void)6998 virtual ~CaseFactories (void) {}
6999 virtual const std::vector<const CaseFactory*> getFactories (void) const = 0;
7000 };
7001
7002 class BuiltinFuncs : public CaseFactories
7003 {
7004 public:
getFactories(void) const7005 const vector<const CaseFactory*> getFactories (void) const
7006 {
7007 vector<const CaseFactory*> ret;
7008
7009 for (size_t ndx = 0; ndx < m_factories.size(); ++ndx)
7010 ret.push_back(m_factories[ndx].get());
7011
7012 return ret;
7013 }
7014
addFactory(SharedPtr<const CaseFactory> fact)7015 void addFactory (SharedPtr<const CaseFactory> fact) { m_factories.push_back(fact); }
7016
7017 private:
7018 vector<SharedPtr<const CaseFactory> > m_factories;
7019 };
7020
7021 template <typename F>
addScalarFactory(BuiltinFuncs & funcs,string name="",bool modularOp=false)7022 void addScalarFactory (BuiltinFuncs& funcs, string name = "", bool modularOp = false)
7023 {
7024 if (name.empty())
7025 name = instance<F>().getName();
7026
7027 funcs.addFactory(SharedPtr<const CaseFactory>(new GenFuncCaseFactory<typename F::Sig>(makeVectorizedFuncs<F>(), name, modularOp)));
7028 }
7029
createBuiltinCases()7030 MovePtr<const CaseFactories> createBuiltinCases ()
7031 {
7032 MovePtr<BuiltinFuncs> funcs (new BuiltinFuncs());
7033
7034 // Tests for ES3 builtins
7035 addScalarFactory<Comparison< Signature<int, float, float> > >(*funcs);
7036 addScalarFactory<Add< Signature<float, float, float> > >(*funcs);
7037 addScalarFactory<Sub< Signature<float, float, float> > >(*funcs);
7038 addScalarFactory<Mul< Signature<float, float, float> > >(*funcs);
7039 addScalarFactory<Div< Signature<float, float, float> > >(*funcs);
7040
7041 addScalarFactory<Radians>(*funcs);
7042 addScalarFactory<Degrees>(*funcs);
7043 addScalarFactory<Sin<Signature<float, float> > >(*funcs);
7044 addScalarFactory<Cos<Signature<float, float> > >(*funcs);
7045 addScalarFactory<Tan>(*funcs);
7046
7047 addScalarFactory<ASin>(*funcs);
7048 addScalarFactory<ACos>(*funcs);
7049 addScalarFactory<ATan2< Signature<float, float, float> > >(*funcs, "atan2");
7050 addScalarFactory<ATan<Signature<float, float> > >(*funcs);
7051 addScalarFactory<Sinh>(*funcs);
7052 addScalarFactory<Cosh>(*funcs);
7053 addScalarFactory<Tanh>(*funcs);
7054 addScalarFactory<ASinh>(*funcs);
7055 addScalarFactory<ACosh>(*funcs);
7056 addScalarFactory<ATanh>(*funcs);
7057
7058 addScalarFactory<Pow>(*funcs);
7059 addScalarFactory<Exp<Signature<float, float> > >(*funcs);
7060 addScalarFactory<Log< Signature<float, float> > >(*funcs);
7061 addScalarFactory<Exp2<Signature<float, float> > >(*funcs);
7062 addScalarFactory<Log2< Signature<float, float> > >(*funcs);
7063 addScalarFactory<Sqrt32Bit>(*funcs);
7064 addScalarFactory<InverseSqrt< Signature<float, float> > >(*funcs);
7065
7066 addScalarFactory<Abs< Signature<float, float> > >(*funcs);
7067 addScalarFactory<Sign< Signature<float, float> > >(*funcs);
7068 addScalarFactory<Floor32Bit>(*funcs);
7069 addScalarFactory<Trunc32Bit>(*funcs);
7070 addScalarFactory<Round< Signature<float, float> > >(*funcs);
7071 addScalarFactory<RoundEven< Signature<float, float> > >(*funcs);
7072 addScalarFactory<Ceil< Signature<float, float> > >(*funcs);
7073 addScalarFactory<Fract>(*funcs);
7074
7075 addScalarFactory<Mod32Bit>(*funcs, "mod", true);
7076 addScalarFactory<FRem32Bit>(*funcs);
7077
7078 addScalarFactory<Modf32Bit>(*funcs);
7079 addScalarFactory<ModfStruct32Bit>(*funcs);
7080 addScalarFactory<Min< Signature<float, float, float> > >(*funcs);
7081 addScalarFactory<Max< Signature<float, float, float> > >(*funcs);
7082 addScalarFactory<Clamp< Signature<float, float, float, float> > >(*funcs);
7083 addScalarFactory<Mix>(*funcs);
7084 addScalarFactory<Step< Signature<float, float, float> > >(*funcs);
7085 addScalarFactory<SmoothStep< Signature<float, float, float, float> > >(*funcs);
7086
7087 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, float>()));
7088 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, float>()));
7089 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, float>()));
7090 funcs->addFactory(createSimpleFuncCaseFactory<Cross>());
7091 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, float>()));
7092 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, float>()));
7093 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, float>()));
7094 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, float>()));
7095
7096 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, float>()));
7097 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, float>()));
7098 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, float>()));
7099 #ifndef CTS_USES_VULKANSC
7100 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant>()));
7101 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse>()));
7102 #endif // CTS_USES_VULKANSC
7103
7104 addScalarFactory<Frexp32Bit>(*funcs);
7105 addScalarFactory<FrexpStruct32Bit>(*funcs);
7106 addScalarFactory<LdExp <Signature<float, float, int> > >(*funcs);
7107 addScalarFactory<Fma <Signature<float, float, float, float> > >(*funcs);
7108
7109 return MovePtr<const CaseFactories>(funcs.release());
7110 }
7111
createBuiltinDoubleCases()7112 MovePtr<const CaseFactories> createBuiltinDoubleCases ()
7113 {
7114 MovePtr<BuiltinFuncs> funcs (new BuiltinFuncs());
7115
7116 // Tests for ES3 builtins
7117 addScalarFactory<Comparison<Signature<int, double, double>>>(*funcs);
7118 addScalarFactory<Add<Signature<double, double, double>>>(*funcs);
7119 addScalarFactory<Sub<Signature<double, double, double>>>(*funcs);
7120 addScalarFactory<Mul<Signature<double, double, double>>>(*funcs);
7121 addScalarFactory<Div<Signature<double, double, double>>>(*funcs);
7122
7123 // Radians, degrees, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, pow, exp, log, exp2 and log2
7124 // only work with 16-bit and 32-bit floating point types according to the spec.
7125 #if 0
7126 addScalarFactory<Radians64>(*funcs);
7127 addScalarFactory<Degrees64>(*funcs);
7128 addScalarFactory<Sin<Signature<double, double>>>(*funcs);
7129 addScalarFactory<Cos<Signature<double, double>>>(*funcs);
7130 addScalarFactory<Tan64Bit>(*funcs);
7131 addScalarFactory<ASin64Bit>(*funcs);
7132 addScalarFactory<ACos64Bit>(*funcs);
7133 addScalarFactory<ATan2<Signature<double, double, double>>>(*funcs, "atan2");
7134 addScalarFactory<ATan<Signature<double, double>>>(*funcs);
7135 addScalarFactory<Sinh64Bit>(*funcs);
7136 addScalarFactory<Cosh64Bit>(*funcs);
7137 addScalarFactory<Tanh64Bit>(*funcs);
7138 addScalarFactory<ASinh64Bit>(*funcs);
7139 addScalarFactory<ACosh64Bit>(*funcs);
7140 addScalarFactory<ATanh64Bit>(*funcs);
7141
7142 addScalarFactory<Pow64>(*funcs);
7143 addScalarFactory<Exp<Signature<double, double>>>(*funcs);
7144 addScalarFactory<Log<Signature<double, double>>>(*funcs);
7145 addScalarFactory<Exp2<Signature<double, double>>>(*funcs);
7146 addScalarFactory<Log2<Signature<double, double>>>(*funcs);
7147 #endif
7148 addScalarFactory<Sqrt64Bit>(*funcs);
7149 addScalarFactory<InverseSqrt<Signature<double, double>>>(*funcs);
7150
7151 addScalarFactory<Abs<Signature<double, double>>>(*funcs);
7152 addScalarFactory<Sign<Signature<double, double>>>(*funcs);
7153 addScalarFactory<Floor64Bit>(*funcs);
7154 addScalarFactory<Trunc64Bit>(*funcs);
7155 addScalarFactory<Round<Signature<double, double>>>(*funcs);
7156 addScalarFactory<RoundEven<Signature<double, double>>>(*funcs);
7157 addScalarFactory<Ceil<Signature<double, double>>>(*funcs);
7158 addScalarFactory<Fract64Bit>(*funcs);
7159
7160 addScalarFactory<Mod64Bit>(*funcs, "mod", true);
7161 addScalarFactory<FRem64Bit>(*funcs);
7162
7163 addScalarFactory<Modf64Bit>(*funcs);
7164 addScalarFactory<ModfStruct64Bit>(*funcs);
7165 addScalarFactory<Min<Signature<double, double, double>>>(*funcs);
7166 addScalarFactory<Max<Signature<double, double, double>>>(*funcs);
7167 addScalarFactory<Clamp<Signature<double, double, double, double>>>(*funcs);
7168 addScalarFactory<Mix64Bit>(*funcs);
7169 addScalarFactory<Step<Signature<double, double, double>>>(*funcs);
7170 addScalarFactory<SmoothStep<Signature<double, double, double, double>>>(*funcs);
7171
7172 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, double>()));
7173 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, double>()));
7174 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, double>()));
7175 funcs->addFactory(createSimpleFuncCaseFactory<Cross64Bit>());
7176 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, double>()));
7177 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, double>()));
7178 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, double>()));
7179 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, double>()));
7180
7181 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, double>()));
7182 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, double>()));
7183 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, double>()));
7184 #ifndef CTS_USES_VULKANSC
7185 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant64bit>()));
7186 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse64bit>()));
7187 #endif // CTS_USES_VULKANSC
7188
7189 addScalarFactory<Frexp64Bit>(*funcs);
7190 addScalarFactory<FrexpStruct64Bit>(*funcs);
7191 addScalarFactory<LdExp<Signature<double, double, int>>>(*funcs);
7192 addScalarFactory<Fma<Signature<double, double, double, double>>>(*funcs);
7193
7194 return MovePtr<const CaseFactories>(funcs.release());
7195 }
7196
createBuiltinCases16Bit(void)7197 MovePtr<const CaseFactories> createBuiltinCases16Bit(void)
7198 {
7199 MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7200
7201 addScalarFactory<Comparison< Signature<int, deFloat16, deFloat16> > >(*funcs);
7202 addScalarFactory<Add< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7203 addScalarFactory<Sub< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7204 addScalarFactory<Mul< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7205 addScalarFactory<Div< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7206
7207 addScalarFactory<Radians16>(*funcs);
7208 addScalarFactory<Degrees16>(*funcs);
7209
7210 addScalarFactory<Sin<Signature<deFloat16, deFloat16> > >(*funcs);
7211 addScalarFactory<Cos<Signature<deFloat16, deFloat16> > >(*funcs);
7212 addScalarFactory<Tan16Bit>(*funcs);
7213 addScalarFactory<ASin16Bit>(*funcs);
7214 addScalarFactory<ACos16Bit>(*funcs);
7215 addScalarFactory<ATan2< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs, "atan2");
7216 addScalarFactory<ATan<Signature<deFloat16, deFloat16> > >(*funcs);
7217
7218 addScalarFactory<Sinh16Bit>(*funcs);
7219 addScalarFactory<Cosh16Bit>(*funcs);
7220 addScalarFactory<Tanh16Bit>(*funcs);
7221 addScalarFactory<ASinh16Bit>(*funcs);
7222 addScalarFactory<ACosh16Bit>(*funcs);
7223 addScalarFactory<ATanh16Bit>(*funcs);
7224
7225 addScalarFactory<Pow16>(*funcs);
7226 addScalarFactory<Exp< Signature<deFloat16, deFloat16> > >(*funcs);
7227 addScalarFactory<Log< Signature<deFloat16, deFloat16> > >(*funcs);
7228 addScalarFactory<Exp2< Signature<deFloat16, deFloat16> > >(*funcs);
7229 addScalarFactory<Log2< Signature<deFloat16, deFloat16> > >(*funcs);
7230 addScalarFactory<Sqrt16Bit>(*funcs);
7231 addScalarFactory<InverseSqrt16Bit>(*funcs);
7232
7233 addScalarFactory<Abs< Signature<deFloat16, deFloat16> > >(*funcs);
7234 addScalarFactory<Sign< Signature<deFloat16, deFloat16> > >(*funcs);
7235 addScalarFactory<Floor16Bit>(*funcs);
7236 addScalarFactory<Trunc16Bit>(*funcs);
7237 addScalarFactory<Round< Signature<deFloat16, deFloat16> > >(*funcs);
7238 addScalarFactory<RoundEven< Signature<deFloat16, deFloat16> > >(*funcs);
7239 addScalarFactory<Ceil< Signature<deFloat16, deFloat16> > >(*funcs);
7240 addScalarFactory<Fract16Bit>(*funcs);
7241
7242 addScalarFactory<Mod16Bit>(*funcs, "mod", true);
7243 addScalarFactory<FRem16Bit>(*funcs);
7244
7245 addScalarFactory<Modf16Bit>(*funcs);
7246 addScalarFactory<ModfStruct16Bit>(*funcs);
7247 addScalarFactory<Min< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7248 addScalarFactory<Max< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7249 addScalarFactory<Clamp< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7250 addScalarFactory<Mix16Bit>(*funcs);
7251 addScalarFactory<Step< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7252 addScalarFactory<SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7253
7254 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, deFloat16>()));
7255 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, deFloat16>()));
7256 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, deFloat16>()));
7257 funcs->addFactory(createSimpleFuncCaseFactory<Cross16Bit>());
7258 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, deFloat16>()));
7259 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, deFloat16>()));
7260 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, deFloat16>()));
7261 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, deFloat16>()));
7262
7263 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, deFloat16>()));
7264 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, deFloat16>()));
7265 #ifndef CTS_USES_VULKANSC
7266 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant16bit>()));
7267 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse16bit>()));
7268 #endif // CTS_USES_VULKANSC
7269
7270 addScalarFactory<Frexp16Bit>(*funcs);
7271 addScalarFactory<FrexpStruct16Bit>(*funcs);
7272 addScalarFactory<LdExp <Signature<deFloat16, deFloat16, int> > >(*funcs);
7273 addScalarFactory<Fma <Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7274
7275 return MovePtr<const CaseFactories>(funcs.release());
7276 }
7277
createFuncGroup(TestContext & ctx,const CaseFactory & factory,int numRandoms)7278 TestCaseGroup* createFuncGroup (TestContext& ctx, const CaseFactory& factory, int numRandoms)
7279 {
7280 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str(), factory.getDesc().c_str());
7281 const FloatFormat highp (-126, 127, 23, true,
7282 tcu::MAYBE, // subnormals
7283 tcu::YES, // infinities
7284 tcu::MAYBE); // NaN
7285 const FloatFormat mediump (-14, 13, 10, false, tcu::MAYBE);
7286
7287 for (int precNdx = glu::PRECISION_MEDIUMP; precNdx < glu::PRECISION_LAST; ++precNdx)
7288 {
7289 const Precision precision = Precision(precNdx);
7290 const string precName (glu::getPrecisionName(precision));
7291 const FloatFormat& fmt = precNdx == glu::PRECISION_MEDIUMP ? mediump : highp;
7292
7293 const CaseContext caseCtx (precName, ctx, fmt, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms);
7294
7295 group->addChild(factory.createCase(caseCtx).release());
7296 }
7297
7298 return group;
7299 }
7300
createFuncGroupDouble(TestContext & ctx,const CaseFactory & factory,int numRandoms)7301 TestCaseGroup* createFuncGroupDouble (TestContext& ctx, const CaseFactory& factory, int numRandoms)
7302 {
7303 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str(), factory.getDesc().c_str());
7304 const Precision precision = Precision(glu::PRECISION_LAST);
7305 const FloatFormat highp (-1022, 1023, 52, true,
7306 tcu::MAYBE, // subnormals
7307 tcu::YES, // infinities
7308 tcu::MAYBE); // NaN
7309
7310 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT;
7311
7312 const CaseContext caseCtx("compute", ctx, highp, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms, precisionTestFeatures, false, true);
7313 group->addChild(factory.createCase(caseCtx).release());
7314
7315 return group;
7316 }
7317
createFuncGroup16Bit(TestContext & ctx,const CaseFactory & factory,int numRandoms,bool storage32)7318 TestCaseGroup* createFuncGroup16Bit(TestContext& ctx, const CaseFactory& factory, int numRandoms, bool storage32)
7319 {
7320 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str(), factory.getDesc().c_str());
7321 const Precision precision = Precision(glu::PRECISION_LAST);
7322 const FloatFormat float16 (-14, 15, 10, true, tcu::MAYBE);
7323
7324 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT;
7325 if (!storage32)
7326 precisionTestFeatures |= PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS;
7327
7328 const CaseContext caseCtx("compute", ctx, float16, float16, precision, glu::SHADERTYPE_COMPUTE, numRandoms, precisionTestFeatures, storage32);
7329 group->addChild(factory.createCase(caseCtx).release());
7330
7331 return group;
7332 }
7333
// Number of random inputs per case, used when the command line gives no
// positive test iteration count (16384).
const int defRandoms = 1 << 14;
7335
addBuiltinPrecisionTests(TestContext & ctx,TestCaseGroup & dstGroup,const bool test16Bit=false,const bool storage32Bit=false)7336 void addBuiltinPrecisionTests (TestContext& ctx,
7337 TestCaseGroup& dstGroup,
7338 const bool test16Bit = false,
7339 const bool storage32Bit = false)
7340 {
7341 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7342 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7343
7344 MovePtr<const CaseFactories> cases = (test16Bit && !storage32Bit) ? createBuiltinCases16Bit()
7345 : createBuiltinCases();
7346 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7347 {
7348 if (!test16Bit)
7349 dstGroup.addChild(createFuncGroup(ctx, *cases->getFactories()[ndx], numRandoms));
7350 else
7351 dstGroup.addChild(createFuncGroup16Bit(ctx, *cases->getFactories()[ndx], numRandoms, storage32Bit));
7352 }
7353 }
7354
addBuiltinPrecisionDoubleTests(TestContext & ctx,TestCaseGroup & dstGroup)7355 void addBuiltinPrecisionDoubleTests (TestContext& ctx,
7356 TestCaseGroup& dstGroup)
7357 {
7358 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7359 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7360
7361 MovePtr<const CaseFactories> cases = createBuiltinDoubleCases();
7362 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7363 {
7364 dstGroup.addChild(createFuncGroupDouble(ctx, *cases->getFactories()[ndx], numRandoms));
7365 }
7366 }
7367
BuiltinPrecisionTests(tcu::TestContext & testCtx)7368 BuiltinPrecisionTests::BuiltinPrecisionTests (tcu::TestContext& testCtx)
7369 : tcu::TestCaseGroup(testCtx, "precision", "Builtin precision tests")
7370 {
7371 }
7372
~BuiltinPrecisionTests(void)7373 BuiltinPrecisionTests::~BuiltinPrecisionTests (void)
7374 {
7375 }
7376
init(void)7377 void BuiltinPrecisionTests::init (void)
7378 {
7379 addBuiltinPrecisionTests(m_testCtx, *this);
7380 }
7381
BuiltinPrecisionDoubleTests(tcu::TestContext & testCtx)7382 BuiltinPrecisionDoubleTests::BuiltinPrecisionDoubleTests (tcu::TestContext& testCtx)
7383 : tcu::TestCaseGroup(testCtx, "precision_double", "Builtin precision tests")
7384 {
7385 }
7386
~BuiltinPrecisionDoubleTests(void)7387 BuiltinPrecisionDoubleTests::~BuiltinPrecisionDoubleTests (void)
7388 {
7389 }
7390
init(void)7391 void BuiltinPrecisionDoubleTests::init (void)
7392 {
7393 addBuiltinPrecisionDoubleTests(m_testCtx, *this);
7394 }
7395
BuiltinPrecision16BitTests(tcu::TestContext & testCtx)7396 BuiltinPrecision16BitTests::BuiltinPrecision16BitTests (tcu::TestContext& testCtx)
7397 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage16b", "Builtin precision tests")
7398 {
7399 }
7400
~BuiltinPrecision16BitTests(void)7401 BuiltinPrecision16BitTests::~BuiltinPrecision16BitTests (void)
7402 {
7403 }
7404
init(void)7405 void BuiltinPrecision16BitTests::init (void)
7406 {
7407 addBuiltinPrecisionTests(m_testCtx, *this, true);
7408 }
7409
BuiltinPrecision16Storage32BitTests(tcu::TestContext & testCtx)7410 BuiltinPrecision16Storage32BitTests::BuiltinPrecision16Storage32BitTests(tcu::TestContext& testCtx)
7411 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage32b", "Builtin precision tests")
7412 {
7413 }
7414
~BuiltinPrecision16Storage32BitTests(void)7415 BuiltinPrecision16Storage32BitTests::~BuiltinPrecision16Storage32BitTests(void)
7416 {
7417 }
7418
init(void)7419 void BuiltinPrecision16Storage32BitTests::init(void)
7420 {
7421 addBuiltinPrecisionTests(m_testCtx, *this, true, true);
7422 }
7423
7424 } // shaderexecutor
7425 } // vkt
7426