1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2018 The Khronos Group Inc.
6 * Copyright (c) 2015 Samsung Electronics Co., Ltd.
7 * Copyright (c) 2016 The Android Open Source Project
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 * http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 *//*!
22 * \file
23 * \brief Precision and range tests for builtins and types.
24 *
25 *//*--------------------------------------------------------------------*/
26
27 #include "vktShaderBuiltinPrecisionTests.hpp"
28 #include "vktShaderExecutor.hpp"
29
30 #include "deMath.h"
31 #include "deMemory.h"
32 #include "deFloat16.h"
33 #include "deDefs.hpp"
34 #include "deRandom.hpp"
35 #include "deSTLUtil.hpp"
36 #include "deStringUtil.hpp"
37 #include "deUniquePtr.hpp"
38 #include "deSharedPtr.hpp"
39 #include "deArrayUtil.hpp"
40
41 #include "tcuCommandLine.hpp"
42 #include "tcuFloatFormat.hpp"
43 #include "tcuInterval.hpp"
44 #include "tcuTestLog.hpp"
45 #include "tcuVector.hpp"
46 #include "tcuMatrix.hpp"
47 #include "tcuResultCollector.hpp"
48 #include "tcuMaybe.hpp"
49
50 #include "gluContextInfo.hpp"
51 #include "gluVarType.hpp"
52 #include "gluRenderContext.hpp"
53 #include "glwDefs.hpp"
54
#include <cmath>
#include <cstring>
#include <string>
#include <sstream>
#include <iostream>
#include <map>
#include <utility>
#include <limits>
62
// Uncomment this to get evaluation trace dumps to std::cerr.
// #define GLS_ENABLE_TRACE

// Set this to true to log every result, including the passing ones.
#define GLS_LOG_ALL_RESULTS false
68
// Bit patterns of frequently used constants encoded as 16-bit floats.
#define FLOAT16_1_0 0x3C00 // 1.0 as a 16-bit float
#define FLOAT16_180_0 0x59A0 // 180.0 as a 16-bit float
#define FLOAT16_2_0 0x4000 // 2.0 as a 16-bit float
#define FLOAT16_3_0 0x4200 // 3.0 as a 16-bit float
#define FLOAT16_0_5 0x3800 // 0.5 as a 16-bit float
#define FLOAT16_0_0 0x0000 // 0.0 as a 16-bit float
75
76
77 using tcu::Vector;
78 typedef Vector<deFloat16, 1> Vec1_16Bit;
79 typedef Vector<deFloat16, 2> Vec2_16Bit;
80 typedef Vector<deFloat16, 3> Vec3_16Bit;
81 typedef Vector<deFloat16, 4> Vec4_16Bit;
82
83 typedef Vector<double, 1> Vec1_64Bit;
84 typedef Vector<double, 2> Vec2_64Bit;
85 typedef Vector<double, 3> Vec3_64Bit;
86 typedef Vector<double, 4> Vec4_64Bit;
87
enum
{
	// Computing reference intervals can take a non-trivial amount of time, especially on
	// platforms where toggling the floating-point rounding mode is slow (emulated ARM on x86).
	// As a workaround the watchdog is kept happy by touching it periodically during reference
	// interval computation.
	// NOTE(review): presumably the number of values processed between two touches -- confirm at the use site.
	TOUCH_WATCHDOG_VALUE_FREQUENCY = 512
};
96
97 namespace vkt
98 {
99 namespace shaderexecutor
100 {
101
102 using std::string;
103 using std::map;
104 using std::ostream;
105 using std::ostringstream;
106 using std::pair;
107 using std::vector;
108 using std::set;
109
110 using de::MovePtr;
111 using de::Random;
112 using de::SharedPtr;
113 using de::UniquePtr;
114 using tcu::Interval;
115 using tcu::FloatFormat;
116 using tcu::MessageBuilder;
117 using tcu::TestLog;
118 using tcu::Vector;
119 using tcu::Matrix;
120 using glu::Precision;
121 using glu::VarType;
122 using glu::DataType;
123 using glu::ShaderType;
124
125 enum PrecisionTestFeatureBits
126 {
127 PRECISION_TEST_FEATURES_NONE = 0u,
128 PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS = (1u << 1),
129 PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS = (1u << 2),
130 PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT = (1u << 3),
131 PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT = (1u << 4),
132 PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT = (1u << 5),
133 PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT = (1u << 6),
134 };
135 typedef deUint32 PrecisionTestFeatures;
136
137
areFeaturesSupported(const Context & context,deUint32 toCheck)138 void areFeaturesSupported (const Context& context, deUint32 toCheck)
139 {
140 if (toCheck == PRECISION_TEST_FEATURES_NONE) return;
141
142 const vk::VkPhysicalDevice16BitStorageFeatures& extensionFeatures = context.get16BitStorageFeatures();
143
144 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS) != 0 && extensionFeatures.storageBuffer16BitAccess == VK_FALSE)
145 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
146
147 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS) != 0 && extensionFeatures.uniformAndStorageBuffer16BitAccess == VK_FALSE)
148 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
149
150 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT) != 0 && extensionFeatures.storagePushConstant16 == VK_FALSE)
151 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
152
153 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT) != 0 && extensionFeatures.storageInputOutput16 == VK_FALSE)
154 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
155
156 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT) != 0 && context.getShaderFloat16Int8Features().shaderFloat16 == VK_FALSE)
157 TCU_THROW(NotSupportedError, "Requested 16-bit floats (halfs) are not supported in shader code");
158
159 if ((toCheck & PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT) != 0 && context.getDeviceFeatures().shaderFloat64 == VK_FALSE)
160 TCU_THROW(NotSupportedError, "Requested 64-bit floats are not supported in shader code");
161 }
162
/*--------------------------------------------------------------------*//*!
 * \brief Generic singleton accessor.
 *
 * instance<T>() hands out a reference to one shared, value-initialized,
 * immutable object of type T. Every call with the same T yields the same
 * object. The GLSL function implementations rely on this: each function is
 * an object of its own distinct class, and keeping all of them in a
 * dedicated context object would be impractical.
 *
 *//*--------------------------------------------------------------------*/
template <typename T>
const T& instance (void)
{
	// Function-local static: created on first use, then reused by every later call.
	static const T s_singleton = T();
	const T& shared = s_singleton;
	return shared;
}
180
181 /*--------------------------------------------------------------------*//*!
182 * \brief Dummy placeholder type for unused template parameters.
183 *
184 * In the precision tests we are dealing with functions of different arities.
185 * To minimize code duplication, we only define templates with the maximum
186 * number of arguments, currently four. If a function's arity is less than the
 * maximum, Void is used as the type for unused arguments.
188 *
189 * Although Voids are not used at run-time, they still must be compilable, so
190 * they must support all operations that other types do.
191 *
192 *//*--------------------------------------------------------------------*/
193 struct Void
194 {
195 typedef Void Element;
196 enum
197 {
198 SIZE = 0,
199 };
200
201 template <typename T>
Voidvkt::shaderexecutor::Void202 explicit Void (const T&) {}
Voidvkt::shaderexecutor::Void203 Void (void) {}
operator doublevkt::shaderexecutor::Void204 operator double (void) const { return TCU_NAN; }
205
206 // These are used to make Voids usable as containers in container-generic code.
operator []vkt::shaderexecutor::Void207 Void& operator[] (int) { return *this; }
operator []vkt::shaderexecutor::Void208 const Void& operator[] (int) const { return *this; }
209 };
210
operator <<(ostream & os,Void)211 ostream& operator<< (ostream& os, Void) { return os << "()"; }
212
213 //! Returns true for all other types except Void
isTypeValid(void)214 template <typename T> bool isTypeValid (void) { return true; }
isTypeValid(void)215 template <> bool isTypeValid<Void> (void) { return false; }
216
isInteger(void)217 template <typename T> bool isInteger (void) { return false; }
isInteger(void)218 template <> bool isInteger<int> (void) { return true; }
isInteger(void)219 template <> bool isInteger<tcu::IVec2> (void) { return true; }
isInteger(void)220 template <> bool isInteger<tcu::IVec3> (void) { return true; }
isInteger(void)221 template <> bool isInteger<tcu::IVec4> (void) { return true; }
222
223 //! Utility function for getting the name of a data type.
224 //! This is used in vector and matrix constructors.
225 template <typename T>
dataTypeNameOf(void)226 const char* dataTypeNameOf (void)
227 {
228 return glu::getDataTypeName(glu::dataTypeOf<T>());
229 }
230
231 template <>
dataTypeNameOf(void)232 const char* dataTypeNameOf<Void> (void)
233 {
234 DE_FATAL("Impossible");
235 return DE_NULL;
236 }
237
238 template <typename T>
getVarTypeOf(Precision prec=glu::PRECISION_LAST)239 VarType getVarTypeOf (Precision prec = glu::PRECISION_LAST)
240 {
241 return glu::varTypeOf<T>(prec);
242 }
243
244 //! A hack to get Void support for VarType.
245 template <>
getVarTypeOf(Precision)246 VarType getVarTypeOf<Void> (Precision)
247 {
248 DE_FATAL("Impossible");
249 return VarType();
250 }
251
252 /*--------------------------------------------------------------------*//*!
253 * \brief Type traits for generalized interval types.
254 *
255 * We are trying to compute sets of acceptable values not only for
256 * float-valued expressions but also for compound values: vectors and
257 * matrices. We approximate a set of vectors as a vector of intervals and
258 * likewise for matrices.
259 *
260 * We now need generalized operations for each type and its interval
261 * approximation. These are given in the type Traits<T>.
262 *
263 * The type Traits<T>::IVal is the approximation of T: it is `Interval` for
264 * scalar types, and a vector or matrix of intervals for container types.
265 *
266 * To allow template inference to take place, there are function wrappers for
267 * the actual operations in Traits<T>. Hence we can just use:
268 *
269 * makeIVal(someFloat)
270 *
271 * instead of:
272 *
273 * Traits<float>::doMakeIVal(value)
274 *
275 *//*--------------------------------------------------------------------*/
276
277 template <typename T> struct Traits;
278
279 //! Create container from elementwise singleton values.
280 template <typename T>
makeIVal(const T & value)281 typename Traits<T>::IVal makeIVal (const T& value)
282 {
283 return Traits<T>::doMakeIVal(value);
284 }
285
286 //! Elementwise union of intervals.
287 template <typename T>
unionIVal(const typename Traits<T>::IVal & a,const typename Traits<T>::IVal & b)288 typename Traits<T>::IVal unionIVal (const typename Traits<T>::IVal& a,
289 const typename Traits<T>::IVal& b)
290 {
291 return Traits<T>::doUnion(a, b);
292 }
293
294 //! Returns true iff every element of `ival` contains the corresponding element of `value`.
295 template <typename T, typename U = Void>
contains(const typename Traits<T>::IVal & ival,const T & value,bool is16Bit=false,const tcu::Maybe<U> & modularDivisor=tcu::nothing<U> ())296 bool contains (const typename Traits<T>::IVal& ival, const T& value, bool is16Bit = false, const tcu::Maybe<U>& modularDivisor = tcu::nothing<U>())
297 {
298 return Traits<T>::doContains(ival, value, is16Bit, modularDivisor);
299 }
300
301 //! Print out an interval with the precision of `fmt`.
302 template <typename T>
printIVal(const FloatFormat & fmt,const typename Traits<T>::IVal & ival,ostream & os)303 void printIVal (const FloatFormat& fmt, const typename Traits<T>::IVal& ival, ostream& os)
304 {
305 Traits<T>::doPrintIVal(fmt, ival, os);
306 }
307
308 template <typename T>
intervalToString(const FloatFormat & fmt,const typename Traits<T>::IVal & ival)309 string intervalToString (const FloatFormat& fmt, const typename Traits<T>::IVal& ival)
310 {
311 ostringstream oss;
312 printIVal<T>(fmt, ival, oss);
313 return oss.str();
314 }
315
316 //! Print out a value with the precision of `fmt`.
317 template <typename T>
printValue16(const FloatFormat & fmt,const T & value,ostream & os)318 void printValue16 (const FloatFormat& fmt, const T& value, ostream& os)
319 {
320 Traits<T>::doPrintValue16(fmt, value, os);
321 }
322
323 template <typename T>
value16ToString(const FloatFormat & fmt,const T & val)324 string value16ToString(const FloatFormat& fmt, const T& val)
325 {
326 ostringstream oss;
327 printValue16(fmt, val, oss);
328 return oss.str();
329 }
330
getComparisonOperation(const int ndx)331 const std::string getComparisonOperation(const int ndx)
332 {
333 const int operationCount = 10;
334 DE_ASSERT(de::inBounds(ndx, 0, operationCount));
335 const std::string operations[operationCount] =
336 {
337 "OpFOrdEqual\t\t\t",
338 "OpFOrdGreaterThan\t",
339 "OpFOrdLessThan\t\t",
340 "OpFOrdGreaterThanEqual",
341 "OpFOrdLessThanEqual\t",
342 "OpFUnordEqual\t\t",
343 "OpFUnordGreaterThan\t",
344 "OpFUnordLessThan\t",
345 "OpFUnordGreaterThanEqual",
346 "OpFUnordLessThanEqual"
347 };
348 return operations[ndx];
349 }
350
351 template <typename T>
comparisonMessage(const T & val)352 string comparisonMessage(const T& val)
353 {
354 DE_UNREF(val);
355 return "";
356 }
357
358 template <>
comparisonMessage(const int & val)359 string comparisonMessage(const int& val)
360 {
361 ostringstream oss;
362
363 int flags = val;
364 for(int ndx = 0; ndx < 10; ++ndx)
365 {
366 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags & 1) == 1 ? "TRUE" : "FALSE") << "\n";
367 flags = flags >> 1;
368 }
369 return oss.str();
370 }
371
372 template <>
comparisonMessage(const tcu::IVec2 & val)373 string comparisonMessage(const tcu::IVec2& val)
374 {
375 ostringstream oss;
376 tcu::IVec2 flags = val;
377 for (int ndx = 0; ndx < 10; ++ndx)
378 {
379 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t" << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
380 flags.x() = flags.x() >> 1;
381 flags.y() = flags.y() >> 1;
382 }
383 return oss.str();
384 }
385
386 template <>
comparisonMessage(const tcu::IVec3 & val)387 string comparisonMessage(const tcu::IVec3& val)
388 {
389 ostringstream oss;
390 tcu::IVec3 flags = val;
391 for (int ndx = 0; ndx < 10; ++ndx)
392 {
393 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
394 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
395 << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
396 flags.x() = flags.x() >> 1;
397 flags.y() = flags.y() >> 1;
398 flags.z() = flags.z() >> 1;
399 }
400 return oss.str();
401 }
402
403 template <>
comparisonMessage(const tcu::IVec4 & val)404 string comparisonMessage(const tcu::IVec4& val)
405 {
406 ostringstream oss;
407 tcu::IVec4 flags = val;
408 for (int ndx = 0; ndx < 10; ++ndx)
409 {
410 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
411 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
412 << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
413 << ((flags.w() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
414 flags.x() = flags.x() >> 1;
415 flags.y() = flags.y() >> 1;
416 flags.z() = flags.z() >> 1;
417 flags.w() = flags.z() >> 1;
418 }
419 return oss.str();
420 }
421 //! Print out a value with the precision of `fmt`.
422 template <typename T>
printValue32(const FloatFormat & fmt,const T & value,ostream & os)423 void printValue32 (const FloatFormat& fmt, const T& value, ostream& os)
424 {
425 Traits<T>::doPrintValue32(fmt, value, os);
426 }
427
428 template <typename T>
value32ToString(const FloatFormat & fmt,const T & val)429 string value32ToString (const FloatFormat& fmt, const T& val)
430 {
431 ostringstream oss;
432 printValue32(fmt, val, oss);
433 return oss.str();
434 }
435
436 template <typename T>
printValue64(const FloatFormat & fmt,const T & value,ostream & os)437 void printValue64 (const FloatFormat& fmt, const T& value, ostream& os)
438 {
439 Traits<T>::doPrintValue64(fmt, value, os);
440 }
441
442 template <typename T>
value64ToString(const FloatFormat & fmt,const T & val)443 string value64ToString (const FloatFormat& fmt, const T& val)
444 {
445 ostringstream oss;
446 printValue64(fmt, val, oss);
447 return oss.str();
448 }
449
450 //! Approximate `value` elementwise to the float precision defined in `fmt`.
451 //! The resulting interval might not be a singleton if rounding in both
452 //! directions is allowed.
453 template <typename T>
round(const FloatFormat & fmt,const T & value)454 typename Traits<T>::IVal round (const FloatFormat& fmt, const T& value)
455 {
456 return Traits<T>::doRound(fmt, value);
457 }
458
459 template <typename T>
convert(const FloatFormat & fmt,const typename Traits<T>::IVal & value)460 typename Traits<T>::IVal convert (const FloatFormat& fmt,
461 const typename Traits<T>::IVal& value)
462 {
463 return Traits<T>::doConvert(fmt, value);
464 }
465
466 // Matching input and output types. We may be in a modulo case and modularDivisor may have an actual value.
467 template <typename T>
intervalContains(const Interval & interval,T value,const tcu::Maybe<T> & modularDivisor)468 bool intervalContains (const Interval& interval, T value, const tcu::Maybe<T>& modularDivisor)
469 {
470 bool contained = interval.contains(value);
471
472 if (!contained && modularDivisor)
473 {
474 const T divisor = modularDivisor.get();
475
476 // In a modulo operation, if the calculated answer contains the divisor, allow exactly 0.0 as a replacement. Alternatively,
477 // if the calculated answer contains 0.0, allow exactly the divisor as a replacement.
478 if (interval.contains(static_cast<double>(divisor)))
479 contained |= (value == 0.0);
480 if (interval.contains(0.0))
481 contained |= (value == divisor);
482 }
483 return contained;
484 }
485
486 // When the input and output types do not match, we are not in a real modulo operation. Do not take the divisor into account. This
487 // version is provided for syntactical compatibility only.
488 template <typename T, typename U>
intervalContains(const Interval & interval,T value,const tcu::Maybe<U> & modularDivisor)489 bool intervalContains (const Interval& interval, T value, const tcu::Maybe<U>& modularDivisor)
490 {
491 DE_UNREF(modularDivisor); // For release builds.
492 DE_ASSERT(!modularDivisor);
493 return interval.contains(value);
494 }
495
496 //! Common traits for scalar types.
497 template <typename T>
498 struct ScalarTraits
499 {
500 typedef Interval IVal;
501
doMakeIValvkt::shaderexecutor::ScalarTraits502 static Interval doMakeIVal (const T& value)
503 {
504 // Thankfully all scalar types have a well-defined conversion to `double`,
505 // hence Interval can represent their ranges without problems.
506 return Interval(double(value));
507 }
508
doUnionvkt::shaderexecutor::ScalarTraits509 static Interval doUnion (const Interval& a, const Interval& b)
510 {
511 return a | b;
512 }
513
doContainsvkt::shaderexecutor::ScalarTraits514 static bool doContains (const Interval& a, T value)
515 {
516 return a.contains(double(value));
517 }
518
doConvertvkt::shaderexecutor::ScalarTraits519 static Interval doConvert (const FloatFormat& fmt, const IVal& ival)
520 {
521 return fmt.convert(ival);
522 }
523
doConvertvkt::shaderexecutor::ScalarTraits524 static Interval doConvert (const FloatFormat& fmt, const IVal& ival, bool is16Bit)
525 {
526 DE_UNREF(is16Bit);
527 return fmt.convert(ival);
528 }
529
doRoundvkt::shaderexecutor::ScalarTraits530 static Interval doRound (const FloatFormat& fmt, T value)
531 {
532 return fmt.roundOut(double(value), false);
533 }
534 };
535
536 template <>
537 struct ScalarTraits<deUint16>
538 {
539 typedef Interval IVal;
540
doMakeIValvkt::shaderexecutor::ScalarTraits541 static Interval doMakeIVal (const deUint16& value)
542 {
543 // Thankfully all scalar types have a well-defined conversion to `double`,
544 // hence Interval can represent their ranges without problems.
545 return Interval(double(deFloat16To32(value)));
546 }
547
doUnionvkt::shaderexecutor::ScalarTraits548 static Interval doUnion (const Interval& a, const Interval& b)
549 {
550 return a | b;
551 }
552
doConvertvkt::shaderexecutor::ScalarTraits553 static Interval doConvert (const FloatFormat& fmt, const IVal& ival)
554 {
555 return fmt.convert(ival);
556 }
557
doRoundvkt::shaderexecutor::ScalarTraits558 static Interval doRound (const FloatFormat& fmt, deUint16 value)
559 {
560 return fmt.roundOut(double(deFloat16To32(value)), false);
561 }
562 };
563
564 template<>
565 struct Traits<float> : ScalarTraits<float>
566 {
doPrintIValvkt::shaderexecutor::Traits567 static void doPrintIVal (const FloatFormat& fmt,
568 const Interval& ival,
569 ostream& os)
570 {
571 os << fmt.intervalToHex(ival);
572 }
573
doPrintValue16vkt::shaderexecutor::Traits574 static void doPrintValue16 (const FloatFormat& fmt,
575 const float& value,
576 ostream& os)
577 {
578 const deUint32 iRep = reinterpret_cast<const deUint32 & >(value);
579 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
580 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
581 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
582 }
583
doPrintValue32vkt::shaderexecutor::Traits584 static void doPrintValue32 (const FloatFormat& fmt,
585 const float& value,
586 ostream& os)
587 {
588 os << fmt.floatToHex(value);
589 }
590
doPrintValue64vkt::shaderexecutor::Traits591 static void doPrintValue64 (const FloatFormat& fmt,
592 const float& value,
593 ostream& os)
594 {
595 os << fmt.floatToHex(value);
596 }
597
598 template <typename U>
doContainsvkt::shaderexecutor::Traits599 static bool doContains (const Interval& a, const float& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
600 {
601 if(is16Bit)
602 {
603 // Note: for deFloat16s packed in 32 bits, the original divisor is provided as a float to the shader in the input
604 // buffer, so U is also float here and we call the right interlvalContains() version.
605 const deUint32 iRep = reinterpret_cast<const deUint32&>(value);
606 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
607 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
608 return intervalContains(a, res0, modularDivisor) && (res1 == -1.0);
609 }
610 return intervalContains(a, value, modularDivisor);
611 }
612 };
613
614 template<>
615 struct Traits<double> : ScalarTraits<double>
616 {
doPrintIValvkt::shaderexecutor::Traits617 static void doPrintIVal (const FloatFormat& fmt,
618 const Interval& ival,
619 ostream& os)
620 {
621 os << fmt.intervalToHex(ival);
622 }
623
doPrintValue16vkt::shaderexecutor::Traits624 static void doPrintValue16 (const FloatFormat& fmt,
625 const double& value,
626 ostream& os)
627 {
628 const deUint64 iRep = reinterpret_cast<const deUint64&>(value);
629 double byte0 = deFloat16To64((deFloat16)((iRep ) & 0xffff));
630 double byte1 = deFloat16To64((deFloat16)((iRep >> 16) & 0xffff));
631 double byte2 = deFloat16To64((deFloat16)((iRep >> 32) & 0xffff));
632 double byte3 = deFloat16To64((deFloat16)((iRep >> 48) & 0xffff));
633 os << fmt.floatToHex(byte0) << " " << fmt.floatToHex(byte1) << " " << fmt.floatToHex(byte2) << " " << fmt.floatToHex(byte3);
634 }
635
doPrintValue32vkt::shaderexecutor::Traits636 static void doPrintValue32 (const FloatFormat& fmt,
637 const double& value,
638 ostream& os)
639 {
640 const deUint64 iRep = reinterpret_cast<const deUint64&>(value);
641 double res0 = static_cast<double>((float)((iRep ) & 0xffffffff));
642 double res1 = static_cast<double>((float)((iRep >> 32) & 0xffffffff));
643 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
644 }
645
doPrintValue64vkt::shaderexecutor::Traits646 static void doPrintValue64 (const FloatFormat& fmt,
647 const double& value,
648 ostream& os)
649 {
650 os << fmt.floatToHex(value);
651 }
652
653 template <class U>
doContainsvkt::shaderexecutor::Traits654 static bool doContains (const Interval& a, const double& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
655 {
656 DE_UNREF(is16Bit);
657 DE_ASSERT(!is16Bit);
658 return intervalContains(a, value, modularDivisor);
659 }
660 };
661
662 template<>
663 struct Traits<deFloat16> : ScalarTraits<deFloat16>
664 {
doPrintIValvkt::shaderexecutor::Traits665 static void doPrintIVal (const FloatFormat& fmt,
666 const Interval& ival,
667 ostream& os)
668 {
669 os << fmt.intervalToHex(ival);
670 }
671
doPrintValue16vkt::shaderexecutor::Traits672 static void doPrintValue16 (const FloatFormat& fmt,
673 const deFloat16& value,
674 ostream& os)
675 {
676 const float res0 = deFloat16To32(value);
677 os << fmt.floatToHex(static_cast<double>(res0));
678 }
doPrintValue32vkt::shaderexecutor::Traits679 static void doPrintValue32 (const FloatFormat& fmt,
680 const deFloat16& value,
681 ostream& os)
682 {
683 const float res0 = deFloat16To32(value);
684 os << fmt.floatToHex(static_cast<double>(res0));
685 }
686
doPrintValue64vkt::shaderexecutor::Traits687 static void doPrintValue64 (const FloatFormat& fmt,
688 const deFloat16& value,
689 ostream& os)
690 {
691 const double res0 = deFloat16To64(value);
692 os << fmt.floatToHex(res0);
693 }
694
695 // When the value and divisor are both deFloat16, convert both to float to call the right intervalContains version.
doContainsvkt::shaderexecutor::Traits696 static bool doContains (const Interval& a, const deFloat16& value, bool is16Bit, const tcu::Maybe<deFloat16>& modularDivisor)
697 {
698 DE_UNREF(is16Bit);
699 float res0 = deFloat16To32(value);
700 const tcu::Maybe<float> convertedDivisor = (modularDivisor ? tcu::just(deFloat16To32(modularDivisor.get())) : tcu::nothing<float>());
701 return intervalContains(a, res0, convertedDivisor);
702 }
703
704 // If the types don't match we should not be in a modulo operation, so no conversion should take place.
705 template <class U>
doContainsvkt::shaderexecutor::Traits706 static bool doContains (const Interval& a, const deFloat16& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
707 {
708 DE_UNREF(is16Bit);
709 float res0 = deFloat16To32(value);
710 return intervalContains(a, res0, modularDivisor);
711 }
712 };
713
714 template<>
715 struct Traits<bool> : ScalarTraits<bool>
716 {
doPrintValue16vkt::shaderexecutor::Traits717 static void doPrintValue16 (const FloatFormat&,
718 const float& value,
719 ostream& os)
720 {
721 os << (value != 0.0f ? "true" : "false");
722 }
723
doPrintValue32vkt::shaderexecutor::Traits724 static void doPrintValue32 (const FloatFormat&,
725 const float& value,
726 ostream& os)
727 {
728 os << (value != 0.0f ? "true" : "false");
729 }
730
doPrintValue64vkt::shaderexecutor::Traits731 static void doPrintValue64 (const FloatFormat&,
732 const float& value,
733 ostream& os)
734 {
735 os << (value != 0.0f ? "true" : "false");
736 }
737
doPrintIValvkt::shaderexecutor::Traits738 static void doPrintIVal (const FloatFormat&,
739 const Interval& ival,
740 ostream& os)
741 {
742 os << "{";
743 if (ival.contains(false))
744 os << "false";
745 if (ival.contains(false) && ival.contains(true))
746 os << ", ";
747 if (ival.contains(true))
748 os << "true";
749 os << "}";
750 }
751 };
752
753 template<>
754 struct Traits<int> : ScalarTraits<int>
755 {
doPrintValue16vkt::shaderexecutor::Traits756 static void doPrintValue16 (const FloatFormat&,
757 const int& value,
758 ostream& os)
759 {
760 int res0 = value & 0xFFFF;
761 int res1 = value >> 16;
762 os << res0 << " " << res1;
763 }
764
doPrintValue32vkt::shaderexecutor::Traits765 static void doPrintValue32 (const FloatFormat&,
766 const int& value,
767 ostream& os)
768 {
769 os << value;
770 }
771
doPrintValue64vkt::shaderexecutor::Traits772 static void doPrintValue64 (const FloatFormat&,
773 const int& value,
774 ostream& os)
775 {
776 os << value;
777 }
778
doPrintIValvkt::shaderexecutor::Traits779 static void doPrintIVal (const FloatFormat&,
780 const Interval& ival,
781 ostream& os)
782 {
783 os << "[" << int(ival.lo()) << ", " << int(ival.hi()) << "]";
784 }
785
786 template <typename U>
doContainsvkt::shaderexecutor::Traits787 static bool doContains (const Interval& a, const int& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
788 {
789 DE_UNREF(is16Bit);
790 return intervalContains(a, value, modularDivisor);
791 }
792 };
793
794 //! Common traits for containers, i.e. vectors and matrices.
795 //! T is the container type itself, I is the same type with interval elements.
796 template <typename T, typename I>
797 struct ContainerTraits
798 {
799 typedef typename T::Element Element;
800 typedef I IVal;
801
doMakeIValvkt::shaderexecutor::ContainerTraits802 static IVal doMakeIVal (const T& value)
803 {
804 IVal ret;
805
806 for (int ndx = 0; ndx < T::SIZE; ++ndx)
807 ret[ndx] = makeIVal(value[ndx]);
808
809 return ret;
810 }
811
doUnionvkt::shaderexecutor::ContainerTraits812 static IVal doUnion (const IVal& a, const IVal& b)
813 {
814 IVal ret;
815
816 for (int ndx = 0; ndx < T::SIZE; ++ndx)
817 ret[ndx] = unionIVal<Element>(a[ndx], b[ndx]);
818
819 return ret;
820 }
821
822 // When the input and output types match, we may be in a modulo operation. If the divisor is provided, use each of its
823 // components to determine if the obtained result is fine.
doContainsvkt::shaderexecutor::ContainerTraits824 static bool doContains (const IVal& ival, const T& value, bool is16Bit, const tcu::Maybe<T>& modularDivisor)
825 {
826 using DivisorElement = typename T::Element;
827
828 for (int ndx = 0; ndx < T::SIZE; ++ndx)
829 {
830 const tcu::Maybe<DivisorElement> divisorElement = (modularDivisor ? tcu::just((*modularDivisor)[ndx]) : tcu::nothing<DivisorElement>());
831 if (!contains(ival[ndx], value[ndx], is16Bit, divisorElement))
832 return false;
833 }
834
835 return true;
836 }
837
838 // When the input and output types do not match we should not be in a modulo operation. This version is provided for syntactical
839 // compatibility.
840 template <typename U>
doContainsvkt::shaderexecutor::ContainerTraits841 static bool doContains (const IVal& ival, const T& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
842 {
843 for (int ndx = 0; ndx < T::SIZE; ++ndx)
844 {
845 if (!contains(ival[ndx], value[ndx], is16Bit, modularDivisor))
846 return false;
847 }
848
849 return true;
850 }
851
doPrintIValvkt::shaderexecutor::ContainerTraits852 static void doPrintIVal (const FloatFormat& fmt, const IVal ival, ostream& os)
853 {
854 os << "(";
855
856 for (int ndx = 0; ndx < T::SIZE; ++ndx)
857 {
858 if (ndx > 0)
859 os << ", ";
860
861 printIVal<Element>(fmt, ival[ndx], os);
862 }
863
864 os << ")";
865 }
866
doPrintValue16vkt::shaderexecutor::ContainerTraits867 static void doPrintValue16 (const FloatFormat& fmt, const T& value, ostream& os)
868 {
869 os << dataTypeNameOf<T>() << "(";
870
871 for (int ndx = 0; ndx < T::SIZE; ++ndx)
872 {
873 if (ndx > 0)
874 os << ", ";
875
876 printValue16<Element>(fmt, value[ndx], os);
877 }
878
879 os << ")";
880 }
881
doPrintValue32vkt::shaderexecutor::ContainerTraits882 static void doPrintValue32 (const FloatFormat& fmt, const T& value, ostream& os)
883 {
884 os << dataTypeNameOf<T>() << "(";
885
886 for (int ndx = 0; ndx < T::SIZE; ++ndx)
887 {
888 if (ndx > 0)
889 os << ", ";
890
891 printValue32<Element>(fmt, value[ndx], os);
892 }
893
894 os << ")";
895 }
896
doPrintValue64vkt::shaderexecutor::ContainerTraits897 static void doPrintValue64 (const FloatFormat& fmt, const T& value, ostream& os)
898 {
899 os << dataTypeNameOf<T>() << "(";
900
901 for (int ndx = 0; ndx < T::SIZE; ++ndx)
902 {
903 if (ndx > 0)
904 os << ", ";
905
906 printValue64<Element>(fmt, value[ndx], os);
907 }
908
909 os << ")";
910 }
911
doConvertvkt::shaderexecutor::ContainerTraits912 static IVal doConvert (const FloatFormat& fmt, const IVal& value)
913 {
914 IVal ret;
915
916 for (int ndx = 0; ndx < T::SIZE; ++ndx)
917 ret[ndx] = convert<Element>(fmt, value[ndx]);
918
919 return ret;
920 }
921
doRoundvkt::shaderexecutor::ContainerTraits922 static IVal doRound (const FloatFormat& fmt, T value)
923 {
924 IVal ret;
925
926 for (int ndx = 0; ndx < T::SIZE; ++ndx)
927 ret[ndx] = round(fmt, value[ndx]);
928
929 return ret;
930 }
931 };
932
933 template <typename T, int Size>
934 struct Traits<Vector<T, Size> > :
935 ContainerTraits<Vector<T, Size>, Vector<typename Traits<T>::IVal, Size> >
936 {
937 };
938
939 template <typename T, int Rows, int Cols>
940 struct Traits<Matrix<T, Rows, Cols> > :
941 ContainerTraits<Matrix<T, Rows, Cols>, Matrix<typename Traits<T>::IVal, Rows, Cols> >
942 {
943 };
944
945 //! Void traits. These are just dummies, but technically valid: a Void is a
946 //! unit type with a single possible value.
947 template<>
948 struct Traits<Void>
949 {
950 typedef Void IVal;
951
doMakeIValvkt::shaderexecutor::Traits952 static Void doMakeIVal (const Void& value) { return value; }
doUnionvkt::shaderexecutor::Traits953 static Void doUnion (const Void&, const Void&) { return Void(); }
doContainsvkt::shaderexecutor::Traits954 static bool doContains (const Void&, Void) { return true; }
955 template <typename U>
doContainsvkt::shaderexecutor::Traits956 static bool doContains (const Void&, const Void& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor) { DE_UNREF(value); DE_UNREF(is16Bit); DE_UNREF(modularDivisor); return true; }
doRoundvkt::shaderexecutor::Traits957 static Void doRound (const FloatFormat&, const Void& value) { return value; }
doConvertvkt::shaderexecutor::Traits958 static Void doConvert (const FloatFormat&, const Void& value) { return value; }
959
doPrintValue16vkt::shaderexecutor::Traits960 static void doPrintValue16 (const FloatFormat&, const Void&, ostream& os)
961 {
962 os << "()";
963 }
964
doPrintValue32vkt::shaderexecutor::Traits965 static void doPrintValue32 (const FloatFormat&, const Void&, ostream& os)
966 {
967 os << "()";
968 }
969
doPrintValue64vkt::shaderexecutor::Traits970 static void doPrintValue64 (const FloatFormat&, const Void&, ostream& os)
971 {
972 os << "()";
973 }
974
doPrintIValvkt::shaderexecutor::Traits975 static void doPrintIVal (const FloatFormat&, const Void&, ostream& os)
976 {
977 os << "()";
978 }
979 };
980
981 //! This is needed for container-generic operations.
982 //! We want a scalar type T to be its own "one-element vector".
983 template <typename T, int Size> struct ContainerOf { typedef Vector<T, Size> Container; };
984
985 template <typename T> struct ContainerOf<T, 1> { typedef T Container; };
986 template <int Size> struct ContainerOf<Void, Size> { typedef Void Container; };
987
// Kludge that is only needed to get the ExprP::operator[] syntactic sugar to work:
// by default the element type is taken from T::Element, while the scalar types
// listed below report void.
template <typename T>
struct ElementOf
{
	typedef typename T::Element	Element;
};

template <>
struct ElementOf<float>
{
	typedef void				Element;
};

template <>
struct ElementOf<double>
{
	typedef void				Element;
};

template <>
struct ElementOf<bool>
{
	typedef void				Element;
};

template <>
struct ElementOf<int>
{
	typedef void				Element;
};
994
995 template <typename T>
comparisonMessageInterval(const typename Traits<T>::IVal & val)996 string comparisonMessageInterval(const typename Traits<T>::IVal& val)
997 {
998 DE_UNREF(val);
999 return "";
1000 }
1001
1002 template <>
comparisonMessageInterval(const Traits<int>::IVal & val)1003 string comparisonMessageInterval<int>(const Traits<int>::IVal& val)
1004 {
1005 return comparisonMessage(static_cast<int>(val.lo()));
1006 }
1007
1008 template <>
comparisonMessageInterval(const Traits<float>::IVal & val)1009 string comparisonMessageInterval<float>(const Traits<float>::IVal& val)
1010 {
1011 return comparisonMessage(static_cast<int>(val.lo()));
1012 }
1013
1014 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,2> & val)1015 string comparisonMessageInterval<tcu::Vector<int, 2> >(const tcu::Vector<tcu::Interval, 2> & val)
1016 {
1017 tcu::IVec2 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()));
1018 return comparisonMessage(result);
1019 }
1020
1021 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,3> & val)1022 string comparisonMessageInterval<tcu::Vector<int, 3> >(const tcu::Vector<tcu::Interval, 3> & val)
1023 {
1024 tcu::IVec3 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()));
1025 return comparisonMessage(result);
1026 }
1027
1028 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,4> & val)1029 string comparisonMessageInterval<tcu::Vector<int, 4> >(const tcu::Vector<tcu::Interval, 4> & val)
1030 {
1031 tcu::IVec4 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()), static_cast<int>(val[3].lo()));
1032 return comparisonMessage(result);
1033 }
1034
1035 /*--------------------------------------------------------------------*//*!
1036 *
1037 * \name Abstract syntax for expressions and statements.
1038 *
1039 * We represent GLSL programs as syntax objects: an Expr<T> represents an
1040 * expression whose GLSL type corresponds to the C++ type T, and a Statement
1041 * represents a statement.
1042 *
1043 * To ease memory management, we use shared pointers to refer to expressions
1044 * and statements. ExprP<T> is a shared pointer to an Expr<T>, and StatementP
1045 * is a shared pointer to a Statement.
1046 *
1047 * \{
1048 *
1049 *//*--------------------------------------------------------------------*/
1050
1051 class ExprBase;
1052 class ExpandContext;
1053 class Statement;
1054 class StatementP;
1055 class FuncBase;
1056 template <typename T> class ExprP;
1057 template <typename T> class Variable;
1058 template <typename T> class VariableP;
1059 template <typename T> class DefaultSampling;
1060
1061 typedef set<const FuncBase*> FuncSet;
1062
1063 template <typename T>
1064 VariableP<T> variable (const string& name);
1065 StatementP compoundStatement (const vector<StatementP>& statements);
1066
1067 /*--------------------------------------------------------------------*//*!
1068 * \brief A variable environment.
1069 *
1070 * An Environment object maintains the mapping between variables of the
1071 * abstract syntax tree and their values.
1072 *
1073 * \todo [2014-03-28 lauri] At least run-time type safety.
1074 *
1075 *//*--------------------------------------------------------------------*/
1076 class Environment
1077 {
1078 public:
1079 template<typename T>
bind(const Variable<T> & variable,const typename Traits<T>::IVal & value)1080 void bind (const Variable<T>& variable,
1081 const typename Traits<T>::IVal& value)
1082 {
1083 deUint8* const data = new deUint8[sizeof(value)];
1084
1085 deMemcpy(data, &value, sizeof(value));
1086 de::insert(m_map, variable.getName(), SharedPtr<deUint8>(data, de::ArrayDeleter<deUint8>()));
1087 }
1088
1089 template<typename T>
lookup(const Variable<T> & variable) const1090 typename Traits<T>::IVal& lookup (const Variable<T>& variable) const
1091 {
1092 deUint8* const data = de::lookup(m_map, variable.getName()).get();
1093
1094 return *reinterpret_cast<typename Traits<T>::IVal*>(data);
1095 }
1096
1097 private:
1098 map<string, SharedPtr<deUint8> > m_map;
1099 };
1100
1101 /*--------------------------------------------------------------------*//*!
1102 * \brief Evaluation context.
1103 *
1104 * The evaluation context contains everything that separates one execution of
1105 * an expression from the next. Currently this means the desired floating
1106 * point precision and the current variable environment.
1107 *
1108 *//*--------------------------------------------------------------------*/
1109 struct EvalContext
1110 {
EvalContextvkt::shaderexecutor::EvalContext1111 EvalContext (const FloatFormat& format_,
1112 Precision floatPrecision_,
1113 Environment& env_,
1114 int callDepth_)
1115 : format (format_)
1116 , floatPrecision (floatPrecision_)
1117 , env (env_)
1118 , callDepth (callDepth_) {}
1119
1120 FloatFormat format;
1121 Precision floatPrecision;
1122 Environment& env;
1123 int callDepth;
1124 };
1125
1126 /*--------------------------------------------------------------------*//*!
1127 * \brief Simple incremental counter.
1128 *
1129 * This is used to make sure that different ExpandContexts will not produce
1130 * overlapping temporary names.
1131 *
1132 *//*--------------------------------------------------------------------*/
class Counter
{
public:
	//! Start counting from `count` (zero by default).
	Counter (int count = 0)
		: m_count(count)
	{
	}

	//! Return the current count and then advance it by one.
	int operator() (void)
	{
		const int current = m_count;

		++m_count;
		return current;
	}

private:
	int		m_count;
};
1142
1143 class ExpandContext
1144 {
1145 public:
ExpandContext(Counter & symCounter)1146 ExpandContext (Counter& symCounter) : m_symCounter(symCounter) {}
ExpandContext(const ExpandContext & parent)1147 ExpandContext (const ExpandContext& parent)
1148 : m_symCounter(parent.m_symCounter) {}
1149
1150 template<typename T>
genSym(const string & baseName)1151 VariableP<T> genSym (const string& baseName)
1152 {
1153 return variable<T>(baseName + de::toString(m_symCounter()));
1154 }
1155
addStatement(const StatementP & stmt)1156 void addStatement (const StatementP& stmt)
1157 {
1158 m_statements.push_back(stmt);
1159 }
1160
getStatements(void) const1161 vector<StatementP> getStatements (void) const
1162 {
1163 return m_statements;
1164 }
1165 private:
1166 Counter& m_symCounter;
1167 vector<StatementP> m_statements;
1168 };
1169
1170 /*--------------------------------------------------------------------*//*!
1171 * \brief A statement or declaration.
1172 *
1173 * Statements have no values. Instead, they are executed for their side
1174 * effects only: the execute() method should modify at least one variable in
1175 * the environment.
1176 *
1177 * As a bit of a kludge, a Statement object can also represent a declaration:
1178 * when it is evaluated, it can add a variable binding to the environment
1179 * instead of modifying a current one.
1180 *
1181 *//*--------------------------------------------------------------------*/
1182 class Statement
1183 {
1184 public:
~Statement(void)1185 virtual ~Statement (void) { }
1186 //! Execute the statement, modifying the environment of `ctx`
execute(EvalContext & ctx) const1187 void execute (EvalContext& ctx) const { this->doExecute(ctx); }
print(ostream & os) const1188 void print (ostream& os) const { this->doPrint(os); }
1189 //! Add the functions used in this statement to `dst`.
getUsedFuncs(FuncSet & dst) const1190 void getUsedFuncs (FuncSet& dst) const { this->doGetUsedFuncs(dst); }
failed(EvalContext & ctx) const1191 void failed (EvalContext& ctx) const { this->doFail(ctx); }
1192
1193 protected:
1194 virtual void doPrint (ostream& os) const = 0;
1195 virtual void doExecute (EvalContext& ctx) const = 0;
1196 virtual void doGetUsedFuncs (FuncSet& dst) const = 0;
doFail(EvalContext & ctx) const1197 virtual void doFail (EvalContext& ctx) const { DE_UNREF(ctx); }
1198 };
1199
operator <<(ostream & os,const Statement & stmt)1200 ostream& operator<<(ostream& os, const Statement& stmt)
1201 {
1202 stmt.print(os);
1203 return os;
1204 }
1205
1206 /*--------------------------------------------------------------------*//*!
1207 * \brief Smart pointer for statements (and declarations)
1208 *
1209 *//*--------------------------------------------------------------------*/
1210 class StatementP : public SharedPtr<const Statement>
1211 {
1212 public:
1213 typedef SharedPtr<const Statement> Super;
1214
StatementP(void)1215 StatementP (void) {}
StatementP(const Statement * ptr)1216 explicit StatementP (const Statement* ptr) : Super(ptr) {}
StatementP(const Super & ptr)1217 StatementP (const Super& ptr) : Super(ptr) {}
1218 };
1219
1220 /*--------------------------------------------------------------------*//*!
1221 * \brief
1222 *
1223 * A statement that modifies a variable or a declaration that binds a variable.
1224 *
1225 *//*--------------------------------------------------------------------*/
1226 template <typename T>
1227 class VariableStatement : public Statement
1228 {
1229 public:
VariableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1230 VariableStatement (const VariableP<T>& variable, const ExprP<T>& value,
1231 bool isDeclaration)
1232 : m_variable (variable)
1233 , m_value (value)
1234 , m_isDeclaration (isDeclaration) {}
1235
1236 protected:
doPrint(ostream & os) const1237 void doPrint (ostream& os) const
1238 {
1239 if (m_isDeclaration)
1240 os << glu::declare(getVarTypeOf<T>(), m_variable->getName());
1241 else
1242 os << m_variable->getName();
1243
1244 os << " = ";
1245 os<< *m_value << ";\n";
1246 }
1247
doExecute(EvalContext & ctx) const1248 void doExecute (EvalContext& ctx) const
1249 {
1250 if (m_isDeclaration)
1251 ctx.env.bind(*m_variable, m_value->evaluate(ctx));
1252 else
1253 ctx.env.lookup(*m_variable) = m_value->evaluate(ctx);
1254 }
1255
doGetUsedFuncs(FuncSet & dst) const1256 void doGetUsedFuncs (FuncSet& dst) const
1257 {
1258 m_value->getUsedFuncs(dst);
1259 }
1260
doFail(EvalContext & ctx) const1261 virtual void doFail (EvalContext& ctx) const
1262 {
1263 if (m_isDeclaration)
1264 ctx.env.bind(*m_variable, m_value->fails(ctx));
1265 else
1266 ctx.env.lookup(*m_variable) = m_value->fails(ctx);
1267 }
1268
1269 VariableP<T> m_variable;
1270 ExprP<T> m_value;
1271 bool m_isDeclaration;
1272 };
1273
1274 template <typename T>
variableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1275 StatementP variableStatement (const VariableP<T>& variable,
1276 const ExprP<T>& value,
1277 bool isDeclaration)
1278 {
1279 return StatementP(new VariableStatement<T>(variable, value, isDeclaration));
1280 }
1281
1282 template <typename T>
variableDeclaration(const VariableP<T> & variable,const ExprP<T> & definiens)1283 StatementP variableDeclaration (const VariableP<T>& variable, const ExprP<T>& definiens)
1284 {
1285 return variableStatement(variable, definiens, true);
1286 }
1287
1288 template <typename T>
variableAssignment(const VariableP<T> & variable,const ExprP<T> & value)1289 StatementP variableAssignment (const VariableP<T>& variable, const ExprP<T>& value)
1290 {
1291 return variableStatement(variable, value, false);
1292 }
1293
1294 /*--------------------------------------------------------------------*//*!
1295 * \brief A compound statement, i.e. a block.
1296 *
1297 * A compound statement is executed by executing its constituent statements in
1298 * sequence.
1299 *
1300 *//*--------------------------------------------------------------------*/
1301 class CompoundStatement : public Statement
1302 {
1303 public:
CompoundStatement(const vector<StatementP> & statements)1304 CompoundStatement (const vector<StatementP>& statements)
1305 : m_statements (statements) {}
1306
1307 protected:
doPrint(ostream & os) const1308 void doPrint (ostream& os) const
1309 {
1310 os << "{\n";
1311
1312 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1313 os << *m_statements[ndx];
1314
1315 os << "}\n";
1316 }
1317
doExecute(EvalContext & ctx) const1318 void doExecute (EvalContext& ctx) const
1319 {
1320 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1321 m_statements[ndx]->execute(ctx);
1322 }
1323
doGetUsedFuncs(FuncSet & dst) const1324 void doGetUsedFuncs (FuncSet& dst) const
1325 {
1326 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1327 m_statements[ndx]->getUsedFuncs(dst);
1328 }
1329
1330 vector<StatementP> m_statements;
1331 };
1332
compoundStatement(const vector<StatementP> & statements)1333 StatementP compoundStatement(const vector<StatementP>& statements)
1334 {
1335 return StatementP(new CompoundStatement(statements));
1336 }
1337
1338 //! Common base class for all expressions regardless of their type.
1339 class ExprBase
1340 {
1341 public:
~ExprBase(void)1342 virtual ~ExprBase (void) {}
printExpr(ostream & os) const1343 void printExpr (ostream& os) const { this->doPrintExpr(os); }
1344
1345 //! Output the functions that this expression refers to
getUsedFuncs(FuncSet & dst) const1346 void getUsedFuncs (FuncSet& dst) const
1347 {
1348 this->doGetUsedFuncs(dst);
1349 }
1350
1351 protected:
doPrintExpr(ostream &) const1352 virtual void doPrintExpr (ostream&) const {}
doGetUsedFuncs(FuncSet &) const1353 virtual void doGetUsedFuncs (FuncSet&) const {}
1354 };
1355
1356 //! Type-specific operations for an expression representing type T.
1357 template <typename T>
1358 class Expr : public ExprBase
1359 {
1360 public:
1361 typedef T Val;
1362 typedef typename Traits<T>::IVal IVal;
1363
1364 IVal evaluate (const EvalContext& ctx) const;
fails(const EvalContext & ctx) const1365 IVal fails (const EvalContext& ctx) const { return this->doFails(ctx); }
1366
1367 protected:
1368 virtual IVal doEvaluate (const EvalContext& ctx) const = 0;
doFails(const EvalContext & ctx) const1369 virtual IVal doFails (const EvalContext& ctx) const {return doEvaluate(ctx);}
1370 };
1371
1372 //! Evaluate an expression with the given context, optionally tracing the calls to stderr.
1373 template <typename T>
evaluate(const EvalContext & ctx) const1374 typename Traits<T>::IVal Expr<T>::evaluate (const EvalContext& ctx) const
1375 {
1376 #ifdef GLS_ENABLE_TRACE
1377 static const FloatFormat highpFmt (-126, 127, 23, true,
1378 tcu::MAYBE,
1379 tcu::YES,
1380 tcu::MAYBE);
1381 EvalContext newCtx (ctx.format, ctx.floatPrecision,
1382 ctx.env, ctx.callDepth + 1);
1383 const IVal ret = this->doEvaluate(newCtx);
1384
1385 if (isTypeValid<T>())
1386 {
1387 std::cerr << string(ctx.callDepth, ' ');
1388 this->printExpr(std::cerr);
1389 std::cerr << " -> " << intervalToString<T>(highpFmt, ret) << std::endl;
1390 }
1391 return ret;
1392 #else
1393 return this->doEvaluate(ctx);
1394 #endif
1395 }
1396
1397 template <typename T>
1398 class ExprPBase : public SharedPtr<const Expr<T> >
1399 {
1400 public:
1401 };
1402
operator <<(ostream & os,const ExprBase & expr)1403 ostream& operator<< (ostream& os, const ExprBase& expr)
1404 {
1405 expr.printExpr(os);
1406 return os;
1407 }
1408
1409 /*--------------------------------------------------------------------*//*!
1410 * \brief Shared pointer to an expression of a container type.
1411 *
1412 * Container types (i.e. vectors and matrices) support the subscription
1413 * operator. This class provides a bit of syntactic sugar to allow us to use
1414 * the C++ subscription operator to create a subscription expression.
1415 *//*--------------------------------------------------------------------*/
1416 template <typename T>
1417 class ContainerExprPBase : public ExprPBase<T>
1418 {
1419 public:
1420 ExprP<typename T::Element> operator[] (int i) const;
1421 };
1422
1423 template <typename T>
1424 class ExprP : public ExprPBase<T> {};
1425
1426 // We treat Voids as containers since the dummy parameters in generalized
1427 // vector functions are represented as Voids.
1428 template <>
1429 class ExprP<Void> : public ContainerExprPBase<Void> {};
1430
1431 template <typename T, int Size>
1432 class ExprP<Vector<T, Size> > : public ContainerExprPBase<Vector<T, Size> > {};
1433
1434 template <typename T, int Rows, int Cols>
1435 class ExprP<Matrix<T, Rows, Cols> > : public ContainerExprPBase<Matrix<T, Rows, Cols> > {};
1436
exprP(void)1437 template <typename T> ExprP<T> exprP (void)
1438 {
1439 return ExprP<T>();
1440 }
1441
1442 template <typename T>
exprP(const SharedPtr<const Expr<T>> & ptr)1443 ExprP<T> exprP (const SharedPtr<const Expr<T> >& ptr)
1444 {
1445 ExprP<T> ret;
1446 static_cast<SharedPtr<const Expr<T> >&>(ret) = ptr;
1447 return ret;
1448 }
1449
1450 template <typename T>
exprP(const Expr<T> * ptr)1451 ExprP<T> exprP (const Expr<T>* ptr)
1452 {
1453 return exprP(SharedPtr<const Expr<T> >(ptr));
1454 }
1455
1456 /*--------------------------------------------------------------------*//*!
1457 * \brief A shared pointer to a variable expression.
1458 *
1459 * This is just a narrowing of ExprP for the operations that require a variable
1460 * instead of an arbitrary expression.
1461 *
1462 *//*--------------------------------------------------------------------*/
1463 template <typename T>
1464 class VariableP : public SharedPtr<const Variable<T> >
1465 {
1466 public:
1467 typedef SharedPtr<const Variable<T> > Super;
VariableP(const Variable<T> * ptr)1468 explicit VariableP (const Variable<T>* ptr) : Super(ptr) {}
VariableP(void)1469 VariableP (void) {}
VariableP(const Super & ptr)1470 VariableP (const Super& ptr) : Super(ptr) {}
1471
operator ExprP<T>(void) const1472 operator ExprP<T> (void) const { return exprP(SharedPtr<const Expr<T> >(*this)); }
1473 };
1474
1475 /*--------------------------------------------------------------------*//*!
1476 * \name Syntactic sugar operators for expressions.
1477 *
1478 * @{
1479 *
1480 * These operators allow the use of C++ syntax to construct GLSL expressions
1481 * containing operators: e.g. "a+b" creates an addition expression with
1482 * operands a and b, and so on.
1483 *
1484 *//*--------------------------------------------------------------------*/
1485 ExprP<float> operator+ (const ExprP<float>& arg0,
1486 const ExprP<float>& arg1);
1487 ExprP<deFloat16> operator+ (const ExprP<deFloat16>& arg0,
1488 const ExprP<deFloat16>& arg1);
1489 ExprP<double> operator+ (const ExprP<double>& arg0,
1490 const ExprP<double>& arg1);
1491 template <typename T>
1492 ExprP<T> operator- (const ExprP<T>& arg0);
1493 template <typename T>
1494 ExprP<T> operator- (const ExprP<T>& arg0,
1495 const ExprP<T>& arg1);
1496 template<int Left, int Mid, int Right, typename T>
1497 ExprP<Matrix<T, Left, Right> > operator* (const ExprP<Matrix<T, Left, Mid> >& left,
1498 const ExprP<Matrix<T, Mid, Right> >& right);
1499 ExprP<float> operator* (const ExprP<float>& arg0,
1500 const ExprP<float>& arg1);
1501 ExprP<deFloat16> operator* (const ExprP<deFloat16>& arg0,
1502 const ExprP<deFloat16>& arg1);
1503 ExprP<double> operator* (const ExprP<double>& arg0,
1504 const ExprP<double>& arg1);
1505 template <typename T>
1506 ExprP<T> operator/ (const ExprP<T>& arg0,
1507 const ExprP<T>& arg1);
1508 template<typename T, int Size>
1509 ExprP<Vector<T, Size> > operator- (const ExprP<Vector<T, Size> >& arg0);
1510 template<typename T, int Size>
1511 ExprP<Vector<T, Size> > operator- (const ExprP<Vector<T, Size> >& arg0,
1512 const ExprP<Vector<T, Size> >& arg1);
1513 template<int Size, typename T>
1514 ExprP<Vector<T, Size> > operator* (const ExprP<Vector<T, Size> >& arg0,
1515 const ExprP<T>& arg1);
1516 template<typename T, int Size>
1517 ExprP<Vector<T, Size> > operator* (const ExprP<Vector<T, Size> >& arg0,
1518 const ExprP<Vector<T, Size> >& arg1);
1519 template<int Rows, int Cols, typename T>
1520 ExprP<Vector<T, Rows> > operator* (const ExprP<Vector<T, Cols> >& left,
1521 const ExprP<Matrix<T, Rows, Cols> >& right);
1522 template<int Rows, int Cols, typename T>
1523 ExprP<Vector<T, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
1524 const ExprP<Vector<T, Rows> >& right);
1525 template<int Rows, int Cols, typename T>
1526 ExprP<Matrix<T, Rows, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
1527 const ExprP<T>& right);
1528 template<int Rows, int Cols>
1529 ExprP<Matrix<float, Rows, Cols> > operator+ (const ExprP<Matrix<float, Rows, Cols> >& left,
1530 const ExprP<Matrix<float, Rows, Cols> >& right);
1531 template<int Rows, int Cols>
1532 ExprP<Matrix<deFloat16, Rows, Cols> > operator+ (const ExprP<Matrix<deFloat16, Rows, Cols> >& left,
1533 const ExprP<Matrix<deFloat16, Rows, Cols> >& right);
1534 template<int Rows, int Cols>
1535 ExprP<Matrix<double, Rows, Cols> > operator+ (const ExprP<Matrix<double, Rows, Cols> >& left,
1536 const ExprP<Matrix<double, Rows, Cols> >& right);
1537 template<typename T, int Rows, int Cols>
1538 ExprP<Matrix<T, Rows, Cols> > operator- (const ExprP<Matrix<T, Rows, Cols> >& mat);
1539
1540 //! @}
1541
1542 /*--------------------------------------------------------------------*//*!
1543 * \brief Variable expression.
1544 *
1545 * A variable is evaluated by looking up its range of possible values from an
1546 * environment.
1547 *//*--------------------------------------------------------------------*/
1548 template <typename T>
1549 class Variable : public Expr<T>
1550 {
1551 public:
1552 typedef typename Expr<T>::IVal IVal;
1553
Variable(const string & name)1554 Variable (const string& name) : m_name (name) {}
getName(void) const1555 string getName (void) const { return m_name; }
1556
1557 protected:
doPrintExpr(ostream & os) const1558 void doPrintExpr (ostream& os) const { os << m_name; }
doEvaluate(const EvalContext & ctx) const1559 IVal doEvaluate (const EvalContext& ctx) const
1560 {
1561 return ctx.env.lookup<T>(*this);
1562 }
1563
1564 private:
1565 string m_name;
1566 };
1567
1568 template <typename T>
variable(const string & name)1569 VariableP<T> variable (const string& name)
1570 {
1571 return VariableP<T>(new Variable<T>(name));
1572 }
1573
1574 template <typename T>
bindExpression(const string & name,ExpandContext & ctx,const ExprP<T> & expr)1575 VariableP<T> bindExpression (const string& name, ExpandContext& ctx, const ExprP<T>& expr)
1576 {
1577 VariableP<T> var = ctx.genSym<T>(name);
1578 ctx.addStatement(variableDeclaration(var, expr));
1579 return var;
1580 }
1581
1582 /*--------------------------------------------------------------------*//*!
1583 * \brief Constant expression.
1584 *
1585 * A constant is evaluated by rounding it to a set of possible values allowed
1586 * by the current floating point precision.
1587 *//*--------------------------------------------------------------------*/
1588 template <typename T>
1589 class Constant : public Expr<T>
1590 {
1591 public:
1592 typedef typename Expr<T>::IVal IVal;
1593
Constant(const T & value)1594 Constant (const T& value) : m_value(value) {}
1595
1596 protected:
doPrintExpr(ostream & os) const1597 void doPrintExpr (ostream& os) const { os << m_value; }
doEvaluate(const EvalContext &) const1598 IVal doEvaluate (const EvalContext&) const { return makeIVal(m_value); }
1599
1600 private:
1601 T m_value;
1602 };
1603
1604 template <typename T>
constant(const T & value)1605 ExprP<T> constant (const T& value)
1606 {
1607 return exprP(new Constant<T>(value));
1608 }
1609
1610 //! Return a reference to a singleton void constant.
voidP(void)1611 const ExprP<Void>& voidP (void)
1612 {
1613 static const ExprP<Void> singleton = constant(Void());
1614
1615 return singleton;
1616 }
1617
1618 /*--------------------------------------------------------------------*//*!
1619 * \brief Four-element tuple.
1620 *
1621 * This is used for various things where we need one thing for each possible
1622 * function parameter. Currently the maximum supported number of parameters is
1623 * four.
1624 *//*--------------------------------------------------------------------*/
1625 template <typename T0 = Void, typename T1 = Void, typename T2 = Void, typename T3 = Void>
1626 struct Tuple4
1627 {
Tuple4vkt::shaderexecutor::Tuple41628 explicit Tuple4 (const T0 e0 = T0(),
1629 const T1 e1 = T1(),
1630 const T2 e2 = T2(),
1631 const T3 e3 = T3())
1632 : a (e0)
1633 , b (e1)
1634 , c (e2)
1635 , d (e3)
1636 {
1637 }
1638
1639 T0 a;
1640 T1 b;
1641 T2 c;
1642 T3 d;
1643 };
1644
1645 /*--------------------------------------------------------------------*//*!
1646 * \brief Function signature.
1647 *
1648 * This is a purely compile-time structure used to bundle all types in a
1649 * function signature together. This makes passing the signature around in
1650 * templates easier, since we only need to take and pass a single Sig instead
1651 * of a bunch of parameter types and a return type.
1652 *
1653 *//*--------------------------------------------------------------------*/
1654 template <typename R,
1655 typename P0 = Void, typename P1 = Void,
1656 typename P2 = Void, typename P3 = Void>
1657 struct Signature
1658 {
1659 typedef R Ret;
1660 typedef P0 Arg0;
1661 typedef P1 Arg1;
1662 typedef P2 Arg2;
1663 typedef P3 Arg3;
1664 typedef typename Traits<Ret>::IVal IRet;
1665 typedef typename Traits<Arg0>::IVal IArg0;
1666 typedef typename Traits<Arg1>::IVal IArg1;
1667 typedef typename Traits<Arg2>::IVal IArg2;
1668 typedef typename Traits<Arg3>::IVal IArg3;
1669
1670 typedef Tuple4< const Arg0&, const Arg1&, const Arg2&, const Arg3&> Args;
1671 typedef Tuple4< const IArg0&, const IArg1&, const IArg2&, const IArg3&> IArgs;
1672 typedef Tuple4< ExprP<Arg0>, ExprP<Arg1>, ExprP<Arg2>, ExprP<Arg3> > ArgExprs;
1673 };
1674
1675 typedef vector<const ExprBase*> BaseArgExprs;
1676
1677 /*--------------------------------------------------------------------*//*!
1678 * \brief Type-independent operations for function objects.
1679 *
1680 *//*--------------------------------------------------------------------*/
1681 class FuncBase
1682 {
1683 public:
~FuncBase(void)1684 virtual ~FuncBase (void) {}
1685 virtual string getName (void) const = 0;
1686 //! Name of extension that this function requires, or empty.
getRequiredExtension(void) const1687 virtual string getRequiredExtension (void) const { return ""; }
getInputRange(const bool is16bit) const1688 virtual Interval getInputRange (const bool is16bit) const {DE_UNREF(is16bit); return Interval(true, -TCU_INFINITY, TCU_INFINITY); }
1689 virtual void print (ostream&,
1690 const BaseArgExprs&) const = 0;
1691 //! Index of output parameter, or -1 if none of the parameters is output.
getOutParamIndex(void) const1692 virtual int getOutParamIndex (void) const { return -1; }
1693
getSpirvCase(void) const1694 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_NONE; }
1695
printDefinition(ostream & os) const1696 void printDefinition (ostream& os) const
1697 {
1698 doPrintDefinition(os);
1699 }
1700
getUsedFuncs(FuncSet & dst) const1701 void getUsedFuncs (FuncSet& dst) const
1702 {
1703 this->doGetUsedFuncs(dst);
1704 }
1705
1706 protected:
1707 virtual void doPrintDefinition (ostream& os) const = 0;
1708 virtual void doGetUsedFuncs (FuncSet& dst) const = 0;
1709 };
1710
1711 typedef Tuple4<string, string, string, string> ParamNames;
1712
1713 /*--------------------------------------------------------------------*//*!
1714 * \brief Function objects.
1715 *
1716 * Each Func object represents a GLSL function. It can be applied to interval
 * arguments, and it returns an interval that is a conservative
1718 * approximation of the image of the GLSL function over the argument
1719 * intervals. That is, it is given a set of possible arguments and it returns
1720 * the set of possible values.
1721 *
1722 *//*--------------------------------------------------------------------*/
1723 template <typename Sig_>
1724 class Func : public FuncBase
1725 {
1726 public:
1727 typedef Sig_ Sig;
1728 typedef typename Sig::Ret Ret;
1729 typedef typename Sig::Arg0 Arg0;
1730 typedef typename Sig::Arg1 Arg1;
1731 typedef typename Sig::Arg2 Arg2;
1732 typedef typename Sig::Arg3 Arg3;
1733 typedef typename Sig::IRet IRet;
1734 typedef typename Sig::IArg0 IArg0;
1735 typedef typename Sig::IArg1 IArg1;
1736 typedef typename Sig::IArg2 IArg2;
1737 typedef typename Sig::IArg3 IArg3;
1738 typedef typename Sig::Args Args;
1739 typedef typename Sig::IArgs IArgs;
1740 typedef typename Sig::ArgExprs ArgExprs;
1741
print(ostream & os,const BaseArgExprs & args) const1742 void print (ostream& os,
1743 const BaseArgExprs& args) const
1744 {
1745 this->doPrint(os, args);
1746 }
1747
apply(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1748 IRet apply (const EvalContext& ctx,
1749 const IArg0& arg0 = IArg0(),
1750 const IArg1& arg1 = IArg1(),
1751 const IArg2& arg2 = IArg2(),
1752 const IArg3& arg3 = IArg3()) const
1753 {
1754 return this->applyArgs(ctx, IArgs(arg0, arg1, arg2, arg3));
1755 }
1756
fail(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1757 IRet fail (const EvalContext& ctx,
1758 const IArg0& arg0 = IArg0(),
1759 const IArg1& arg1 = IArg1(),
1760 const IArg2& arg2 = IArg2(),
1761 const IArg3& arg3 = IArg3()) const
1762 {
1763 return this->doFail(ctx, IArgs(arg0, arg1, arg2, arg3));
1764 }
applyArgs(const EvalContext & ctx,const IArgs & args) const1765 IRet applyArgs (const EvalContext& ctx,
1766 const IArgs& args) const
1767 {
1768 return this->doApply(ctx, args);
1769 }
1770 ExprP<Ret> operator() (const ExprP<Arg0>& arg0 = voidP(),
1771 const ExprP<Arg1>& arg1 = voidP(),
1772 const ExprP<Arg2>& arg2 = voidP(),
1773 const ExprP<Arg3>& arg3 = voidP()) const;
1774
getParamNames(void) const1775 const ParamNames& getParamNames (void) const
1776 {
1777 return this->doGetParamNames();
1778 }
1779
1780 protected:
1781 virtual IRet doApply (const EvalContext&,
1782 const IArgs&) const = 0;
doFail(const EvalContext & ctx,const IArgs & args) const1783 virtual IRet doFail (const EvalContext& ctx,
1784 const IArgs& args) const
1785 {
1786 return this->doApply(ctx, args);
1787 }
doPrint(ostream & os,const BaseArgExprs & args) const1788 virtual void doPrint (ostream& os, const BaseArgExprs& args) const
1789 {
1790 os << getName() << "(";
1791
1792 if (isTypeValid<Arg0>())
1793 os << *args[0];
1794
1795 if (isTypeValid<Arg1>())
1796 os << ", " << *args[1];
1797
1798 if (isTypeValid<Arg2>())
1799 os << ", " << *args[2];
1800
1801 if (isTypeValid<Arg3>())
1802 os << ", " << *args[3];
1803
1804 os << ")";
1805 }
1806
doGetParamNames(void) const1807 virtual const ParamNames& doGetParamNames (void) const
1808 {
1809 static ParamNames names ("a", "b", "c", "d");
1810 return names;
1811 }
1812 };
1813
1814 template <typename Sig>
1815 class Apply : public Expr<typename Sig::Ret>
1816 {
1817 public:
1818 typedef typename Sig::Ret Ret;
1819 typedef typename Sig::Arg0 Arg0;
1820 typedef typename Sig::Arg1 Arg1;
1821 typedef typename Sig::Arg2 Arg2;
1822 typedef typename Sig::Arg3 Arg3;
1823 typedef typename Expr<Ret>::Val Val;
1824 typedef typename Expr<Ret>::IVal IVal;
1825 typedef Func<Sig> ApplyFunc;
1826 typedef typename ApplyFunc::ArgExprs ArgExprs;
1827
Apply(const ApplyFunc & func,const ExprP<Arg0> & arg0=voidP (),const ExprP<Arg1> & arg1=voidP (),const ExprP<Arg2> & arg2=voidP (),const ExprP<Arg3> & arg3=voidP ())1828 Apply (const ApplyFunc& func,
1829 const ExprP<Arg0>& arg0 = voidP(),
1830 const ExprP<Arg1>& arg1 = voidP(),
1831 const ExprP<Arg2>& arg2 = voidP(),
1832 const ExprP<Arg3>& arg3 = voidP())
1833 : m_func (func),
1834 m_args (arg0, arg1, arg2, arg3) {}
1835
Apply(const ApplyFunc & func,const ArgExprs & args)1836 Apply (const ApplyFunc& func,
1837 const ArgExprs& args)
1838 : m_func (func),
1839 m_args (args) {}
1840 protected:
doPrintExpr(ostream & os) const1841 void doPrintExpr (ostream& os) const
1842 {
1843 BaseArgExprs args;
1844 args.push_back(m_args.a.get());
1845 args.push_back(m_args.b.get());
1846 args.push_back(m_args.c.get());
1847 args.push_back(m_args.d.get());
1848 m_func.print(os, args);
1849 }
1850
doEvaluate(const EvalContext & ctx) const1851 IVal doEvaluate (const EvalContext& ctx) const
1852 {
1853 return m_func.apply(ctx,
1854 m_args.a->evaluate(ctx), m_args.b->evaluate(ctx),
1855 m_args.c->evaluate(ctx), m_args.d->evaluate(ctx));
1856 }
1857
doGetUsedFuncs(FuncSet & dst) const1858 void doGetUsedFuncs (FuncSet& dst) const
1859 {
1860 m_func.getUsedFuncs(dst);
1861 m_args.a->getUsedFuncs(dst);
1862 m_args.b->getUsedFuncs(dst);
1863 m_args.c->getUsedFuncs(dst);
1864 m_args.d->getUsedFuncs(dst);
1865 }
1866
1867 const ApplyFunc& m_func;
1868 ArgExprs m_args;
1869 };
1870
1871 template<typename T>
1872 class Alternatives : public Func<Signature<T, T, T> >
1873 {
1874 public:
1875 typedef typename Alternatives::Sig Sig;
1876
1877 protected:
1878 typedef typename Alternatives::IRet IRet;
1879 typedef typename Alternatives::IArgs IArgs;
1880
getName(void) const1881 virtual string getName (void) const { return "alternatives"; }
doPrintDefinition(std::ostream &) const1882 virtual void doPrintDefinition (std::ostream&) const {}
doGetUsedFuncs(FuncSet &) const1883 void doGetUsedFuncs (FuncSet&) const {}
1884
doApply(const EvalContext &,const IArgs & args) const1885 virtual IRet doApply (const EvalContext&, const IArgs& args) const
1886 {
1887 return unionIVal<T>(args.a, args.b);
1888 }
1889
doPrint(ostream & os,const BaseArgExprs & args) const1890 virtual void doPrint (ostream& os, const BaseArgExprs& args) const
1891 {
1892 os << "{" << *args[0] << " | " << *args[1] << "}";
1893 }
1894 };
1895
1896 template <typename Sig>
createApply(const Func<Sig> & func,const typename Func<Sig>::ArgExprs & args)1897 ExprP<typename Sig::Ret> createApply (const Func<Sig>& func,
1898 const typename Func<Sig>::ArgExprs& args)
1899 {
1900 return exprP(new Apply<Sig>(func, args));
1901 }
1902
1903 template <typename Sig>
createApply(const Func<Sig> & func,const ExprP<typename Sig::Arg0> & arg0=voidP (),const ExprP<typename Sig::Arg1> & arg1=voidP (),const ExprP<typename Sig::Arg2> & arg2=voidP (),const ExprP<typename Sig::Arg3> & arg3=voidP ())1904 ExprP<typename Sig::Ret> createApply (
1905 const Func<Sig>& func,
1906 const ExprP<typename Sig::Arg0>& arg0 = voidP(),
1907 const ExprP<typename Sig::Arg1>& arg1 = voidP(),
1908 const ExprP<typename Sig::Arg2>& arg2 = voidP(),
1909 const ExprP<typename Sig::Arg3>& arg3 = voidP())
1910 {
1911 return exprP(new Apply<Sig>(func, arg0, arg1, arg2, arg3));
1912 }
1913
1914 template <typename Sig>
operator ()(const ExprP<typename Sig::Arg0> & arg0,const ExprP<typename Sig::Arg1> & arg1,const ExprP<typename Sig::Arg2> & arg2,const ExprP<typename Sig::Arg3> & arg3) const1915 ExprP<typename Sig::Ret> Func<Sig>::operator() (const ExprP<typename Sig::Arg0>& arg0,
1916 const ExprP<typename Sig::Arg1>& arg1,
1917 const ExprP<typename Sig::Arg2>& arg2,
1918 const ExprP<typename Sig::Arg3>& arg3) const
1919 {
1920 return createApply(*this, arg0, arg1, arg2, arg3);
1921 }
1922
1923 template <typename F>
app(const ExprP<typename F::Arg0> & arg0=voidP (),const ExprP<typename F::Arg1> & arg1=voidP (),const ExprP<typename F::Arg2> & arg2=voidP (),const ExprP<typename F::Arg3> & arg3=voidP ())1924 ExprP<typename F::Ret> app (const ExprP<typename F::Arg0>& arg0 = voidP(),
1925 const ExprP<typename F::Arg1>& arg1 = voidP(),
1926 const ExprP<typename F::Arg2>& arg2 = voidP(),
1927 const ExprP<typename F::Arg3>& arg3 = voidP())
1928 {
1929 return createApply(instance<F>(), arg0, arg1, arg2, arg3);
1930 }
1931
1932 template <typename F>
call(const EvalContext & ctx,const typename F::IArg0 & arg0=Void (),const typename F::IArg1 & arg1=Void (),const typename F::IArg2 & arg2=Void (),const typename F::IArg3 & arg3=Void ())1933 typename F::IRet call (const EvalContext& ctx,
1934 const typename F::IArg0& arg0 = Void(),
1935 const typename F::IArg1& arg1 = Void(),
1936 const typename F::IArg2& arg2 = Void(),
1937 const typename F::IArg3& arg3 = Void())
1938 {
1939 return instance<F>().apply(ctx, arg0, arg1, arg2, arg3);
1940 }
1941
1942 template <typename T>
alternatives(const ExprP<T> & arg0,const ExprP<T> & arg1)1943 ExprP<T> alternatives (const ExprP<T>& arg0,
1944 const ExprP<T>& arg1)
1945 {
1946 return createApply<typename Alternatives<T>::Sig>(instance<Alternatives<T> >(), arg0, arg1);
1947 }
1948
1949 template <typename Sig>
1950 class ApplyVar : public Apply<Sig>
1951 {
1952 public:
1953 typedef typename Sig::Ret Ret;
1954 typedef typename Sig::Arg0 Arg0;
1955 typedef typename Sig::Arg1 Arg1;
1956 typedef typename Sig::Arg2 Arg2;
1957 typedef typename Sig::Arg3 Arg3;
1958 typedef typename Expr<Ret>::Val Val;
1959 typedef typename Expr<Ret>::IVal IVal;
1960 typedef Func<Sig> ApplyFunc;
1961 typedef typename ApplyFunc::ArgExprs ArgExprs;
1962
ApplyVar(const ApplyFunc & func,const VariableP<Arg0> & arg0,const VariableP<Arg1> & arg1,const VariableP<Arg2> & arg2,const VariableP<Arg3> & arg3)1963 ApplyVar (const ApplyFunc& func,
1964 const VariableP<Arg0>& arg0,
1965 const VariableP<Arg1>& arg1,
1966 const VariableP<Arg2>& arg2,
1967 const VariableP<Arg3>& arg3)
1968 : Apply<Sig> (func, arg0, arg1, arg2, arg3) {}
1969 protected:
doEvaluate(const EvalContext & ctx) const1970 IVal doEvaluate (const EvalContext& ctx) const
1971 {
1972 const Variable<Arg0>& var0 = static_cast<const Variable<Arg0>&>(*this->m_args.a);
1973 const Variable<Arg1>& var1 = static_cast<const Variable<Arg1>&>(*this->m_args.b);
1974 const Variable<Arg2>& var2 = static_cast<const Variable<Arg2>&>(*this->m_args.c);
1975 const Variable<Arg3>& var3 = static_cast<const Variable<Arg3>&>(*this->m_args.d);
1976 return this->m_func.apply(ctx,
1977 ctx.env.lookup(var0), ctx.env.lookup(var1),
1978 ctx.env.lookup(var2), ctx.env.lookup(var3));
1979 }
1980
doFails(const EvalContext & ctx) const1981 IVal doFails (const EvalContext& ctx) const
1982 {
1983 const Variable<Arg0>& var0 = static_cast<const Variable<Arg0>&>(*this->m_args.a);
1984 const Variable<Arg1>& var1 = static_cast<const Variable<Arg1>&>(*this->m_args.b);
1985 const Variable<Arg2>& var2 = static_cast<const Variable<Arg2>&>(*this->m_args.c);
1986 const Variable<Arg3>& var3 = static_cast<const Variable<Arg3>&>(*this->m_args.d);
1987 return this->m_func.fail(ctx,
1988 ctx.env.lookup(var0), ctx.env.lookup(var1),
1989 ctx.env.lookup(var2), ctx.env.lookup(var3));
1990 }
1991 };
1992
1993 template <typename Sig>
applyVar(const Func<Sig> & func,const VariableP<typename Sig::Arg0> & arg0,const VariableP<typename Sig::Arg1> & arg1,const VariableP<typename Sig::Arg2> & arg2,const VariableP<typename Sig::Arg3> & arg3)1994 ExprP<typename Sig::Ret> applyVar (const Func<Sig>& func,
1995 const VariableP<typename Sig::Arg0>& arg0,
1996 const VariableP<typename Sig::Arg1>& arg1,
1997 const VariableP<typename Sig::Arg2>& arg2,
1998 const VariableP<typename Sig::Arg3>& arg3)
1999 {
2000 return exprP(new ApplyVar<Sig>(func, arg0, arg1, arg2, arg3));
2001 }
2002
2003 template <typename Sig_>
2004 class DerivedFunc : public Func<Sig_>
2005 {
2006 public:
2007 typedef typename DerivedFunc::ArgExprs ArgExprs;
2008 typedef typename DerivedFunc::IRet IRet;
2009 typedef typename DerivedFunc::IArgs IArgs;
2010 typedef typename DerivedFunc::Ret Ret;
2011 typedef typename DerivedFunc::Arg0 Arg0;
2012 typedef typename DerivedFunc::Arg1 Arg1;
2013 typedef typename DerivedFunc::Arg2 Arg2;
2014 typedef typename DerivedFunc::Arg3 Arg3;
2015 typedef typename DerivedFunc::IArg0 IArg0;
2016 typedef typename DerivedFunc::IArg1 IArg1;
2017 typedef typename DerivedFunc::IArg2 IArg2;
2018 typedef typename DerivedFunc::IArg3 IArg3;
2019
2020 protected:
doPrintDefinition(ostream & os) const2021 void doPrintDefinition (ostream& os) const
2022 {
2023 const ParamNames& paramNames = this->getParamNames();
2024
2025 initialize();
2026
2027 os << dataTypeNameOf<Ret>() << " " << this->getName()
2028 << "(";
2029 if (isTypeValid<Arg0>())
2030 os << dataTypeNameOf<Arg0>() << " " << paramNames.a;
2031 if (isTypeValid<Arg1>())
2032 os << ", " << dataTypeNameOf<Arg1>() << " " << paramNames.b;
2033 if (isTypeValid<Arg2>())
2034 os << ", " << dataTypeNameOf<Arg2>() << " " << paramNames.c;
2035 if (isTypeValid<Arg3>())
2036 os << ", " << dataTypeNameOf<Arg3>() << " " << paramNames.d;
2037 os << ")\n{\n";
2038
2039 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2040 os << *m_body[ndx];
2041 os << "return " << *m_ret << ";\n";
2042 os << "}\n";
2043 }
2044
doApply(const EvalContext & ctx,const IArgs & args) const2045 IRet doApply (const EvalContext& ctx,
2046 const IArgs& args) const
2047 {
2048 Environment funEnv;
2049 IArgs& mutArgs = const_cast<IArgs&>(args);
2050 IRet ret;
2051
2052 initialize();
2053
2054 funEnv.bind(*m_var0, args.a);
2055 funEnv.bind(*m_var1, args.b);
2056 funEnv.bind(*m_var2, args.c);
2057 funEnv.bind(*m_var3, args.d);
2058
2059 {
2060 EvalContext funCtx(ctx.format, ctx.floatPrecision, funEnv, ctx.callDepth);
2061
2062 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2063 m_body[ndx]->execute(funCtx);
2064
2065 ret = m_ret->evaluate(funCtx);
2066 }
2067
2068 // \todo [lauri] Store references instead of values in environment
2069 const_cast<IArg0&>(mutArgs.a) = funEnv.lookup(*m_var0);
2070 const_cast<IArg1&>(mutArgs.b) = funEnv.lookup(*m_var1);
2071 const_cast<IArg2&>(mutArgs.c) = funEnv.lookup(*m_var2);
2072 const_cast<IArg3&>(mutArgs.d) = funEnv.lookup(*m_var3);
2073
2074 return ret;
2075 }
2076
doGetUsedFuncs(FuncSet & dst) const2077 void doGetUsedFuncs (FuncSet& dst) const
2078 {
2079 initialize();
2080 if (dst.insert(this).second)
2081 {
2082 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2083 m_body[ndx]->getUsedFuncs(dst);
2084 m_ret->getUsedFuncs(dst);
2085 }
2086 }
2087
2088 virtual ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args_) const = 0;
2089
2090 // These are transparently initialized when first needed. They cannot be
2091 // initialized in the constructor because they depend on the doExpand
2092 // method of the subclass.
2093
2094 mutable VariableP<Arg0> m_var0;
2095 mutable VariableP<Arg1> m_var1;
2096 mutable VariableP<Arg2> m_var2;
2097 mutable VariableP<Arg3> m_var3;
2098 mutable vector<StatementP> m_body;
2099 mutable ExprP<Ret> m_ret;
2100
2101 private:
2102
initialize(void) const2103 void initialize (void) const
2104 {
2105 if (!m_ret)
2106 {
2107 const ParamNames& paramNames = this->getParamNames();
2108 Counter symCounter;
2109 ExpandContext ctx (symCounter);
2110 ArgExprs args;
2111
2112 args.a = m_var0 = variable<Arg0>(paramNames.a);
2113 args.b = m_var1 = variable<Arg1>(paramNames.b);
2114 args.c = m_var2 = variable<Arg2>(paramNames.c);
2115 args.d = m_var3 = variable<Arg3>(paramNames.d);
2116
2117 m_ret = this->doExpand(ctx, args);
2118 m_body = ctx.getStatements();
2119 }
2120 }
2121 };
2122
2123 template <typename Sig>
2124 class PrimitiveFunc : public Func<Sig>
2125 {
2126 public:
2127 typedef typename PrimitiveFunc::Ret Ret;
2128 typedef typename PrimitiveFunc::ArgExprs ArgExprs;
2129
2130 protected:
doPrintDefinition(ostream &) const2131 void doPrintDefinition (ostream&) const {}
doGetUsedFuncs(FuncSet &) const2132 void doGetUsedFuncs (FuncSet&) const {}
2133 };
2134
2135 template <typename T>
2136 class Cond : public PrimitiveFunc<Signature<T, bool, T, T> >
2137 {
2138 public:
2139 typedef typename Cond::IArgs IArgs;
2140 typedef typename Cond::IRet IRet;
2141
getName(void) const2142 string getName (void) const
2143 {
2144 return "_cond";
2145 }
2146
2147 protected:
2148
doPrint(ostream & os,const BaseArgExprs & args) const2149 void doPrint (ostream& os, const BaseArgExprs& args) const
2150 {
2151 os << "(" << *args[0] << " ? " << *args[1] << " : " << *args[2] << ")";
2152 }
2153
doApply(const EvalContext &,const IArgs & iargs) const2154 IRet doApply (const EvalContext&, const IArgs& iargs)const
2155 {
2156 IRet ret;
2157
2158 if (iargs.a.contains(true))
2159 ret = unionIVal<T>(ret, iargs.b);
2160
2161 if (iargs.a.contains(false))
2162 ret = unionIVal<T>(ret, iargs.c);
2163
2164 return ret;
2165 }
2166 };
2167
2168 template <typename T>
2169 class CompareOperator : public PrimitiveFunc<Signature<bool, T, T> >
2170 {
2171 public:
2172 typedef typename CompareOperator::IArgs IArgs;
2173 typedef typename CompareOperator::IArg0 IArg0;
2174 typedef typename CompareOperator::IArg1 IArg1;
2175 typedef typename CompareOperator::IRet IRet;
2176
2177 protected:
doPrint(ostream & os,const BaseArgExprs & args) const2178 void doPrint (ostream& os, const BaseArgExprs& args) const
2179 {
2180 os << "(" << *args[0] << getSymbol() << *args[1] << ")";
2181 }
2182
doApply(const EvalContext &,const IArgs & iargs) const2183 Interval doApply (const EvalContext&, const IArgs& iargs) const
2184 {
2185 const IArg0& arg0 = iargs.a;
2186 const IArg1& arg1 = iargs.b;
2187 IRet ret;
2188
2189 if (canSucceed(arg0, arg1))
2190 ret |= true;
2191 if (canFail(arg0, arg1))
2192 ret |= false;
2193
2194 return ret;
2195 }
2196
2197 virtual string getSymbol (void) const = 0;
2198 virtual bool canSucceed (const IArg0&, const IArg1&) const = 0;
2199 virtual bool canFail (const IArg0&, const IArg1&) const = 0;
2200 };
2201
2202 template <typename T>
2203 class LessThan : public CompareOperator<T>
2204 {
2205 public:
getName(void) const2206 string getName (void) const { return "lessThan"; }
2207
2208 protected:
getSymbol(void) const2209 string getSymbol (void) const { return "<"; }
2210
canSucceed(const Interval & a,const Interval & b) const2211 bool canSucceed (const Interval& a, const Interval& b) const
2212 {
2213 return (a.lo() < b.hi());
2214 }
2215
canFail(const Interval & a,const Interval & b) const2216 bool canFail (const Interval& a, const Interval& b) const
2217 {
2218 return !(a.hi() < b.lo());
2219 }
2220 };
2221
2222 template <typename T>
operator <(const ExprP<T> & a,const ExprP<T> & b)2223 ExprP<bool> operator< (const ExprP<T>& a, const ExprP<T>& b)
2224 {
2225 return app<LessThan<T> >(a, b);
2226 }
2227
2228 template <typename T>
cond(const ExprP<bool> & test,const ExprP<T> & consequent,const ExprP<T> & alternative)2229 ExprP<T> cond (const ExprP<bool>& test,
2230 const ExprP<T>& consequent,
2231 const ExprP<T>& alternative)
2232 {
2233 return app<Cond<T> >(test, consequent, alternative);
2234 }
2235
2236 /*--------------------------------------------------------------------*//*!
2237 *
2238 * @}
2239 *
2240 *//*--------------------------------------------------------------------*/
2241 //Proper parameters for template T
2242 // Signature<float, float> 32bit tests
2243 // Signature<float, deFloat16> 16bit tests
2244 // Signature<double, double> 64bit tests
2245 template< class T>
2246 class FloatFunc1 : public PrimitiveFunc<T>
2247 {
2248 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0>::IArgs & iargs) const2249 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0>::IArgs& iargs) const
2250 {
2251 return this->applyMonotone(ctx, iargs.a);
2252 }
2253
applyMonotone(const EvalContext & ctx,const Interval & iarg0) const2254 Interval applyMonotone (const EvalContext& ctx, const Interval& iarg0) const
2255 {
2256 Interval ret;
2257
2258 TCU_INTERVAL_APPLY_MONOTONE1(ret, arg0, iarg0, val,
2259 TCU_SET_INTERVAL(val, point,
2260 point = this->applyPoint(ctx, arg0)));
2261
2262 ret |= innerExtrema(ctx, iarg0);
2263 ret &= (this->getCodomain(ctx) | TCU_NAN);
2264
2265 return ctx.format.convert(ret);
2266 }
2267
innerExtrema(const EvalContext &,const Interval &) const2268 virtual Interval innerExtrema (const EvalContext&, const Interval&) const
2269 {
2270 return Interval(); // empty interval, i.e. no extrema
2271 }
2272
applyPoint(const EvalContext & ctx,double arg0) const2273 virtual Interval applyPoint (const EvalContext& ctx, double arg0) const
2274 {
2275 const double exact = this->applyExact(arg0);
2276 const double prec = this->precision(ctx, exact, arg0);
2277
2278 return exact + Interval(-prec, prec);
2279 }
2280
applyExact(double) const2281 virtual double applyExact (double) const
2282 {
2283 TCU_THROW(InternalError, "Cannot apply");
2284 }
2285
getCodomain(const EvalContext &) const2286 virtual Interval getCodomain (const EvalContext&) const
2287 {
2288 return Interval::unbounded(true);
2289 }
2290
2291 virtual double precision (const EvalContext& ctx, double, double) const = 0;
2292 };
2293
2294 /*Proper parameters for template T
2295 Signature<double, double> 64bit tests
2296 Signature<float, float> 32bit tests
2297 Signature<float, deFloat16> 16bit tests*/
2298 template <class T>
2299 class CFloatFunc1 : public FloatFunc1<T>
2300 {
2301 public:
CFloatFunc1(const string & name,tcu::DoubleFunc1 & func)2302 CFloatFunc1 (const string& name, tcu::DoubleFunc1& func)
2303 : m_name(name), m_func(func) {}
2304
getName(void) const2305 string getName (void) const { return m_name; }
2306
2307 protected:
applyExact(double x) const2308 double applyExact (double x) const { return m_func(x); }
2309
2310 const string m_name;
2311 tcu::DoubleFunc1& m_func;
2312 };
2313
2314 //<Signature<float, deFloat16, deFloat16> >
2315 //<Signature<float, float, float> >
2316 //<Signature<double, double, double> >
2317 template <class T>
2318 class FloatFunc2 : public PrimitiveFunc<T>
2319 {
2320 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2321 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2322 {
2323 return this->applyMonotone(ctx, iargs.a, iargs.b);
2324 }
2325
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi) const2326 Interval applyMonotone (const EvalContext& ctx,
2327 const Interval& xi,
2328 const Interval& yi) const
2329 {
2330 Interval reti;
2331
2332 TCU_INTERVAL_APPLY_MONOTONE2(reti, x, xi, y, yi, ret,
2333 TCU_SET_INTERVAL(ret, point,
2334 point = this->applyPoint(ctx, x, y)));
2335 reti |= innerExtrema(ctx, xi, yi);
2336 reti &= (this->getCodomain(ctx) | TCU_NAN);
2337
2338 return ctx.format.convert(reti);
2339 }
2340
innerExtrema(const EvalContext &,const Interval &,const Interval &) const2341 virtual Interval innerExtrema (const EvalContext&,
2342 const Interval&,
2343 const Interval&) const
2344 {
2345 return Interval(); // empty interval, i.e. no extrema
2346 }
2347
applyPoint(const EvalContext & ctx,double x,double y) const2348 virtual Interval applyPoint (const EvalContext& ctx,
2349 double x,
2350 double y) const
2351 {
2352 const double exact = this->applyExact(x, y);
2353 const double prec = this->precision(ctx, exact, x, y);
2354
2355 return exact + Interval(-prec, prec);
2356 }
2357
applyExact(double,double) const2358 virtual double applyExact (double, double) const
2359 {
2360 TCU_THROW(InternalError, "Cannot apply");
2361 }
2362
getCodomain(const EvalContext &) const2363 virtual Interval getCodomain (const EvalContext&) const
2364 {
2365 return Interval::unbounded(true);
2366 }
2367
2368 virtual double precision (const EvalContext& ctx,
2369 double ret,
2370 double x,
2371 double y) const = 0;
2372 };
2373
2374 template <class T>
2375 class CFloatFunc2 : public FloatFunc2<T>
2376 {
2377 public:
CFloatFunc2(const string & name,tcu::DoubleFunc2 & func)2378 CFloatFunc2 (const string& name,
2379 tcu::DoubleFunc2& func)
2380 : m_name(name)
2381 , m_func(func)
2382 {
2383 }
2384
getName(void) const2385 string getName (void) const { return m_name; }
2386
2387 protected:
applyExact(double x,double y) const2388 double applyExact (double x, double y) const { return m_func(x, y); }
2389
2390 const string m_name;
2391 tcu::DoubleFunc2& m_func;
2392 };
2393
2394 template <class T>
2395 class InfixOperator : public FloatFunc2<T>
2396 {
2397 protected:
2398 virtual string getSymbol (void) const = 0;
2399
doPrint(ostream & os,const BaseArgExprs & args) const2400 void doPrint (ostream& os, const BaseArgExprs& args) const
2401 {
2402 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2403 }
2404
applyPoint(const EvalContext & ctx,double x,double y) const2405 Interval applyPoint (const EvalContext& ctx,
2406 double x,
2407 double y) const
2408 {
2409 const double exact = this->applyExact(x, y);
2410
2411 // Allow either representable number on both sides of the exact value,
2412 // but require exactly representable values to be preserved.
2413 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2414 }
2415
precision(const EvalContext &,double,double,double) const2416 double precision (const EvalContext&, double, double, double) const
2417 {
2418 return 0.0;
2419 }
2420 };
2421
2422 class InfixOperator16Bit : public FloatFunc2 <Signature<float, deFloat16, deFloat16> >
2423 {
2424 protected:
2425 virtual string getSymbol (void) const = 0;
2426
doPrint(ostream & os,const BaseArgExprs & args) const2427 void doPrint (ostream& os, const BaseArgExprs& args) const
2428 {
2429 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2430 }
2431
applyPoint(const EvalContext & ctx,double x,double y) const2432 Interval applyPoint (const EvalContext& ctx,
2433 double x,
2434 double y) const
2435 {
2436 const double exact = this->applyExact(x, y);
2437
2438 // Allow either representable number on both sides of the exact value,
2439 // but require exactly representable values to be preserved.
2440 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2441 }
2442
precision(const EvalContext &,double,double,double) const2443 double precision (const EvalContext&, double, double, double) const
2444 {
2445 return 0.0;
2446 }
2447 };
2448
2449 template <class T>
2450 class FloatFunc3 : public PrimitiveFunc<T>
2451 {
2452 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1,typename T::Arg2>::IArgs & iargs) const2453 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1, typename T::Arg2>::IArgs& iargs) const
2454 {
2455 return this->applyMonotone(ctx, iargs.a, iargs.b, iargs.c);
2456 }
2457
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi,const Interval & zi) const2458 Interval applyMonotone (const EvalContext& ctx,
2459 const Interval& xi,
2460 const Interval& yi,
2461 const Interval& zi) const
2462 {
2463 Interval reti;
2464 TCU_INTERVAL_APPLY_MONOTONE3(reti, x, xi, y, yi, z, zi, ret,
2465 TCU_SET_INTERVAL(ret, point,
2466 point = this->applyPoint(ctx, x, y, z)));
2467 return ctx.format.convert(reti);
2468 }
2469
applyPoint(const EvalContext & ctx,double x,double y,double z) const2470 virtual Interval applyPoint (const EvalContext& ctx,
2471 double x,
2472 double y,
2473 double z) const
2474 {
2475 const double exact = this->applyExact(x, y, z);
2476 const double prec = this->precision(ctx, exact, x, y, z);
2477 return exact + Interval(-prec, prec);
2478 }
2479
applyExact(double,double,double) const2480 virtual double applyExact (double, double, double) const
2481 {
2482 TCU_THROW(InternalError, "Cannot apply");
2483 }
2484
2485 virtual double precision (const EvalContext& ctx,
2486 double result,
2487 double x,
2488 double y,
2489 double z) const = 0;
2490 };
2491
2492 // We define syntactic sugar functions for expression constructors. Since
2493 // these have the same names as ordinary mathematical operations (sin, log
2494 // etc.), it's better to give them a dedicated namespace.
2495 namespace Functions
2496 {
2497
2498 using namespace tcu;
2499
2500 template <class T>
2501 class Comparison : public InfixOperator < T >
2502 {
2503 public:
getName(void) const2504 string getName (void) const { return "comparison"; }
getSymbol(void) const2505 string getSymbol (void) const { return ""; }
2506
getSpirvCase() const2507 SpirVCaseT getSpirvCase () const { return SPIRV_CASETYPE_COMPARE; }
2508
doApply(const EvalContext & ctx,const typename Comparison<T>::IArgs & iargs) const2509 Interval doApply (const EvalContext& ctx,
2510 const typename Comparison<T>::IArgs& iargs) const
2511 {
2512 DE_UNREF(ctx);
2513 if (iargs.a.hasNaN() || iargs.b.hasNaN())
2514 {
2515 return TCU_NAN; // one of the floats is NaN: block analysis
2516 }
2517
2518 int operationFlag = 1;
2519 int result = 0;
2520 const double a = iargs.a.midpoint();
2521 const double b = iargs.b.midpoint();
2522
2523 for (int i = 0; i<2; ++i)
2524 {
2525 if (a == b)
2526 result += operationFlag;
2527 operationFlag = operationFlag << 1;
2528
2529 if (a > b)
2530 result += operationFlag;
2531 operationFlag = operationFlag << 1;
2532
2533 if (a < b)
2534 result += operationFlag;
2535 operationFlag = operationFlag << 1;
2536
2537 if (a >= b)
2538 result += operationFlag;
2539 operationFlag = operationFlag << 1;
2540
2541 if (a <= b)
2542 result += operationFlag;
2543 operationFlag = operationFlag << 1;
2544 }
2545 return result;
2546 }
2547 };
2548
2549 template <class T>
2550 class Add : public InfixOperator < T >
2551 {
2552 public:
getName(void) const2553 string getName (void) const { return "add"; }
getSymbol(void) const2554 string getSymbol (void) const { return "+"; }
2555
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2556 Interval doApply (const EvalContext& ctx,
2557 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2558 {
2559 // Fast-path for common case
2560 if (iargs.a.isOrdinary() && iargs.b.isOrdinary())
2561 {
2562 Interval ret;
2563 TCU_SET_INTERVAL_BOUNDS(ret, sum,
2564 sum = iargs.a.lo() + iargs.b.lo(),
2565 sum = iargs.a.hi() + iargs.b.hi());
2566 return ctx.format.convert(ctx.format.roundOut(ret, true));
2567 }
2568 return this->applyMonotone(ctx, iargs.a, iargs.b);
2569 }
2570
2571 protected:
applyExact(double x,double y) const2572 double applyExact (double x, double y) const { return x + y; }
2573 };
2574
2575 template<class T>
2576 class Mul : public InfixOperator<T>
2577 {
2578 public:
getName(void) const2579 string getName (void) const { return "mul"; }
getSymbol(void) const2580 string getSymbol (void) const { return "*"; }
2581
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2582 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2583 {
2584 Interval a = iargs.a;
2585 Interval b = iargs.b;
2586
2587 // Fast-path for common case
2588 if (a.isOrdinary() && b.isOrdinary())
2589 {
2590 Interval ret;
2591 if (a.hi() < 0)
2592 {
2593 a = -a;
2594 b = -b;
2595 }
2596 if (a.lo() >= 0 && b.lo() >= 0)
2597 {
2598 TCU_SET_INTERVAL_BOUNDS(ret, prod,
2599 prod = a.lo() * b.lo(),
2600 prod = a.hi() * b.hi());
2601 return ctx.format.convert(ctx.format.roundOut(ret, true));
2602 }
2603 if (a.lo() >= 0 && b.hi() <= 0)
2604 {
2605 TCU_SET_INTERVAL_BOUNDS(ret, prod,
2606 prod = a.hi() * b.lo(),
2607 prod = a.lo() * b.hi());
2608 return ctx.format.convert(ctx.format.roundOut(ret, true));
2609 }
2610 }
2611 return this->applyMonotone(ctx, iargs.a, iargs.b);
2612 }
2613
2614 protected:
applyExact(double x,double y) const2615 double applyExact (double x, double y) const { return x * y; }
2616
innerExtrema(const EvalContext &,const Interval & xi,const Interval & yi) const2617 Interval innerExtrema(const EvalContext&, const Interval& xi, const Interval& yi) const
2618 {
2619 if (((xi.contains(-TCU_INFINITY) || xi.contains(TCU_INFINITY)) && yi.contains(0.0)) ||
2620 ((yi.contains(-TCU_INFINITY) || yi.contains(TCU_INFINITY)) && xi.contains(0.0)))
2621 return Interval(TCU_NAN);
2622
2623 return Interval();
2624 }
2625 };
2626
2627 template<class T>
2628 class Sub : public InfixOperator <T>
2629 {
2630 public:
getName(void) const2631 string getName (void) const { return "sub"; }
getSymbol(void) const2632 string getSymbol (void) const { return "-"; }
2633
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2634 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2635 {
2636 // Fast-path for common case
2637 if (iargs.a.isOrdinary() && iargs.b.isOrdinary())
2638 {
2639 Interval ret;
2640
2641 TCU_SET_INTERVAL_BOUNDS(ret, diff,
2642 diff = iargs.a.lo() - iargs.b.hi(),
2643 diff = iargs.a.hi() - iargs.b.lo());
2644 return ctx.format.convert(ctx.format.roundOut(ret, true));
2645
2646 }
2647 else
2648 {
2649 return this->applyMonotone(ctx, iargs.a, iargs.b);
2650 }
2651 }
2652
2653 protected:
applyExact(double x,double y) const2654 double applyExact (double x, double y) const { return x - y; }
2655 };
2656
2657 template <class T>
2658 class Negate : public FloatFunc1<T>
2659 {
2660 public:
getName(void) const2661 string getName (void) const { return "_negate"; }
doPrint(ostream & os,const BaseArgExprs & args) const2662 void doPrint (ostream& os, const BaseArgExprs& args) const { os << "-" << *args[0]; }
2663
2664 protected:
precision(const EvalContext &,double,double) const2665 double precision (const EvalContext&, double, double) const { return 0.0; }
applyExact(double x) const2666 double applyExact (double x) const { return -x; }
2667 };
2668
2669 template <class T>
2670 class Div : public InfixOperator<T>
2671 {
2672 public:
getName(void) const2673 string getName (void) const { return "div"; }
2674
2675 protected:
getSymbol(void) const2676 string getSymbol (void) const { return "/"; }
2677
innerExtrema(const EvalContext &,const Interval & nom,const Interval & den) const2678 Interval innerExtrema (const EvalContext&,
2679 const Interval& nom,
2680 const Interval& den) const
2681 {
2682 Interval ret;
2683
2684 if (den.contains(0.0))
2685 {
2686 if (nom.contains(0.0))
2687 ret |= TCU_NAN;
2688
2689 if (nom.lo() < 0.0 || nom.hi() > 0.0)
2690 ret |= Interval::unbounded();
2691 }
2692
2693 return ret;
2694 }
2695
applyExact(double x,double y) const2696 double applyExact (double x, double y) const { return x / y; }
2697
applyPoint(const EvalContext & ctx,double x,double y) const2698 Interval applyPoint (const EvalContext& ctx, double x, double y) const
2699 {
2700 Interval ret = FloatFunc2<T>::applyPoint(ctx, x, y);
2701
2702 if (!deIsInf(x) && !deIsInf(y) && y != 0.0)
2703 {
2704 const Interval dst = ctx.format.convert(ret);
2705 if (dst.contains(-TCU_INFINITY)) ret |= -ctx.format.getMaxValue();
2706 if (dst.contains(+TCU_INFINITY)) ret |= +ctx.format.getMaxValue();
2707 }
2708
2709 return ret;
2710 }
2711
precision(const EvalContext & ctx,double ret,double,double den) const2712 double precision (const EvalContext& ctx, double ret, double, double den) const
2713 {
2714 const FloatFormat& fmt = ctx.format;
2715
2716 // \todo [2014-03-05 lauri] Check that the limits in GLSL 3.10 are actually correct.
2717 // For now, we assume that division's precision is 2.5 ULP when the value is within
2718 // [2^MINEXP, 2^MAXEXP-1]
2719
2720 if (den == 0.0)
2721 return 0.0; // Result must be exactly inf
2722 else if (de::inBounds(deAbs(den),
2723 deLdExp(1.0, fmt.getMinExp()),
2724 deLdExp(1.0, fmt.getMaxExp() - 1)))
2725 return fmt.ulp(ret, 2.5);
2726 else
2727 return TCU_INFINITY; // Can be any number, but must be a number.
2728 }
2729 };
2730
2731 template <class T>
2732 class InverseSqrt : public FloatFunc1 <T>
2733 {
2734 public:
getName(void) const2735 string getName (void) const { return "inversesqrt"; }
2736
2737 protected:
applyExact(double x) const2738 double applyExact (double x) const { return 1.0 / deSqrt(x); }
2739
precision(const EvalContext & ctx,double ret,double x) const2740 double precision (const EvalContext& ctx, double ret, double x) const
2741 {
2742 return x <= 0 ? TCU_NAN : ctx.format.ulp(ret, 2.0);
2743 }
2744
getCodomain(const EvalContext &) const2745 Interval getCodomain (const EvalContext&) const
2746 {
2747 return Interval(0.0, TCU_INFINITY);
2748 }
2749 };
2750
2751 template <class T>
2752 class ExpFunc : public CFloatFunc1<T>
2753 {
2754 public:
ExpFunc(const string & name,DoubleFunc1 & func)2755 ExpFunc (const string& name, DoubleFunc1& func)
2756 : CFloatFunc1<T> (name, func)
2757 {}
2758 protected:
2759 double precision (const EvalContext& ctx, double ret, double x) const;
getCodomain(const EvalContext &) const2760 Interval getCodomain (const EvalContext&) const
2761 {
2762 return Interval(0.0, TCU_INFINITY);
2763 }
2764 };
2765
2766 template <>
precision(const EvalContext & ctx,double ret,double x) const2767 double ExpFunc <Signature<float, float> >::precision (const EvalContext& ctx, double ret, double x) const
2768 {
2769 switch (ctx.floatPrecision)
2770 {
2771 case glu::PRECISION_HIGHP:
2772 return ctx.format.ulp(ret, 3.0 + 2.0 * deAbs(x));
2773 case glu::PRECISION_MEDIUMP:
2774 case glu::PRECISION_LAST:
2775 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2776 default:
2777 DE_FATAL("Impossible");
2778 }
2779
2780 return 0.0;
2781 }
2782
2783 template <>
precision(const EvalContext & ctx,double ret,double x) const2784 double ExpFunc <Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double x) const
2785 {
2786 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2787 }
2788
2789 template <>
precision(const EvalContext & ctx,double ret,double x) const2790 double ExpFunc <Signature<double, double> >::precision(const EvalContext& ctx, double ret, double x) const
2791 {
2792 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2793 }
2794
2795 template <class T>
Exp2(void)2796 class Exp2 : public ExpFunc<T> { public: Exp2 (void) : ExpFunc<T>("exp2", deExp2) {} };
2797 template <class T>
Exp(void)2798 class Exp : public ExpFunc<T> { public: Exp (void) : ExpFunc<T>("exp", deExp) {} };
2799
2800 template <typename T>
exp2(const ExprP<T> & x)2801 ExprP<T> exp2 (const ExprP<T>& x) { return app<Exp2< Signature<T, T> > >(x); }
2802 template <typename T>
exp(const ExprP<T> & x)2803 ExprP<T> exp (const ExprP<T>& x) { return app<Exp< Signature<T, T> > >(x); }
2804
2805 template <class T>
2806 class LogFunc : public CFloatFunc1<T>
2807 {
2808 public:
LogFunc(const string & name,DoubleFunc1 & func)2809 LogFunc (const string& name, DoubleFunc1& func)
2810 : CFloatFunc1<T>(name, func) {}
2811
2812 protected:
2813 double precision (const EvalContext& ctx, double ret, double x) const;
2814 };
2815
2816 template <>
precision(const EvalContext & ctx,double ret,double x) const2817 double LogFunc<Signature<float, float> >::precision(const EvalContext& ctx, double ret, double x) const
2818 {
2819 if (x <= 0)
2820 return TCU_NAN;
2821
2822 switch (ctx.floatPrecision)
2823 {
2824 case glu::PRECISION_HIGHP:
2825 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
2826 case glu::PRECISION_MEDIUMP:
2827 case glu::PRECISION_LAST:
2828 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
2829 default:
2830 DE_FATAL("Impossible");
2831 }
2832
2833 return 0;
2834 }
2835
2836 template <>
precision(const EvalContext & ctx,double ret,double x) const2837 double LogFunc<Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double x) const
2838 {
2839 if (x <= 0)
2840 return TCU_NAN;
2841 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
2842 }
2843
2844 // Spec: "The precision of double-precision instructions is at least that of single precision."
2845 // Lets pick float high precision as a reference.
2846 template <>
precision(const EvalContext & ctx,double ret,double x) const2847 double LogFunc<Signature<double, double> >::precision(const EvalContext& ctx, double ret, double x) const
2848 {
2849 if (x <= 0)
2850 return TCU_NAN;
2851 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
2852 }
2853
2854 template <class T>
Log2(void)2855 class Log2 : public LogFunc<T> { public: Log2 (void) : LogFunc<T>("log2", deLog2) {} };
2856 template <class T>
Log(void)2857 class Log : public LogFunc<T> { public: Log (void) : LogFunc<T>("log", deLog) {} };
2858
// Expression builders for log2(x) and log(x), one overload per supported
// element type (32-bit float, 16-bit float, 64-bit float).
ExprP<float> log2 (const ExprP<float>& x)
{
	typedef Log2<Signature<float, float> > Func;
	return app<Func>(x);
}

ExprP<float> log (const ExprP<float>& x)
{
	typedef Log<Signature<float, float> > Func;
	return app<Func>(x);
}

ExprP<deFloat16> log2 (const ExprP<deFloat16>& x)
{
	typedef Log2<Signature<deFloat16, deFloat16> > Func;
	return app<Func>(x);
}

ExprP<deFloat16> log (const ExprP<deFloat16>& x)
{
	typedef Log<Signature<deFloat16, deFloat16> > Func;
	return app<Func>(x);
}

ExprP<double> log2 (const ExprP<double>& x)
{
	typedef Log2<Signature<double, double> > Func;
	return app<Func>(x);
}

ExprP<double> log (const ExprP<double>& x)
{
	typedef Log<Signature<double, double> > Func;
	return app<Func>(x);
}
2867
// Defines a free function NAME(arg0) that builds an expression applying
// function class CLASS to a single argument of type T0, yielding ExprP<TRET>.
#define DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)	\
ExprP<TRET> NAME (const ExprP<T0>& arg0) { return app<CLASS>(arg0); }

// Defines a one-argument derived function class CLASS (a DerivedFunc over
// Signature<TRET, T0>) together with its NAME(arg0) constructor function.
// getName() returns the stringified NAME; doExpand() binds the argument
// expression to the identifier ARG0 and returns EXPANSION, which may refer
// to ARG0.
#define DEFINE_DERIVED1(CLASS, TRET, NAME, T0, ARG0, EXPANSION)	\
class CLASS : public DerivedFunc<Signature<TRET, T0> > /* NOLINT(CLASS) */	\
{	\
public:	\
	string getName (void) const { return #NAME; }	\
	\
protected:	\
	ExprP<TRET> doExpand (ExpandContext&,	\
		const CLASS::ArgExprs& args_) const	\
	{	\
		const ExprP<T0>& ARG0 = args_.a;	\
		return EXPANSION;	\
	}	\
};	\
DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)

// DEFINE_DERIVED1 with both return and argument type fixed to double.
#define DEFINE_DERIVED_DOUBLE1(CLASS, NAME, ARG0, EXPANSION)	\
DEFINE_DERIVED1(CLASS, double, NAME, double, ARG0, EXPANSION)

// DEFINE_DERIVED1 with both return and argument type fixed to float.
#define DEFINE_DERIVED_FLOAT1(CLASS, NAME, ARG0, EXPANSION)	\
DEFINE_DERIVED1(CLASS, float, NAME, float, ARG0, EXPANSION)
2892
2893
// Same as DEFINE_DERIVED1, but the generated class additionally provides
// getInputRange(), which ignores its is16bit argument and returns INTERVAL.
#define DEFINE_DERIVED1_INPUTRANGE(CLASS, TRET, NAME, T0, ARG0, EXPANSION, INTERVAL)	\
class CLASS : public DerivedFunc<Signature<TRET, T0> > /* NOLINT(CLASS) */	\
{	\
public:	\
	string getName (void) const { return #NAME; }	\
	\
protected:	\
	ExprP<TRET> doExpand (ExpandContext&,	\
		const CLASS::ArgExprs& args_) const	\
	{	\
		const ExprP<T0>& ARG0 = args_.a;	\
		return EXPANSION;	\
	}	\
	Interval getInputRange (const bool /*is16bit*/) const	\
	{	\
		return INTERVAL;	\
	}	\
};	\
DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)

// DEFINE_DERIVED1_INPUTRANGE with return and argument type fixed to float.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL)	\
DEFINE_DERIVED1_INPUTRANGE(CLASS, float, NAME, float, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1_INPUTRANGE with return and argument type fixed to double.
#define DEFINE_DERIVED_DOUBLE1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL)	\
DEFINE_DERIVED1_INPUTRANGE(CLASS, double, NAME, double, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1 with return and argument type fixed to deFloat16.
#define DEFINE_DERIVED_FLOAT1_16BIT(CLASS, NAME, ARG0, EXPANSION)	\
DEFINE_DERIVED1(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION)

// DEFINE_DERIVED1_INPUTRANGE with return and argument type fixed to deFloat16.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(CLASS, NAME, ARG0, EXPANSION, INTERVAL)	\
DEFINE_DERIVED1_INPUTRANGE(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION, INTERVAL)
2925
// Defines a free function NAME(arg0, arg1) that builds an expression applying
// function class CLASS to two arguments, yielding ExprP<TRET>.
#define DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)	\
ExprP<TRET> NAME (const ExprP<T0>& arg0, const ExprP<T1>& arg1)	\
{	\
	return app<CLASS>(arg0, arg1);	\
}

// Defines a two-argument derived function class CLASS (a DerivedFunc over
// Signature<TRET, T0, T1>) together with its NAME(arg0, arg1) constructor
// function. getName() returns the stringified NAME, getSpirvCase() returns
// SPIRVCASE, and doExpand() binds the argument expressions to the identifiers
// Arg0 and Arg1 before returning EXPANSION.
#define DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRVCASE)	\
class CLASS : public DerivedFunc<Signature<TRET, T0, T1> > /* NOLINT(CLASS) */	\
{	\
public:	\
	string getName (void) const { return #NAME; }	\
	\
	SpirVCaseT getSpirvCase(void) const { return SPIRVCASE; }	\
	\
protected:	\
	ExprP<TRET> doExpand (ExpandContext&, const ArgExprs& args_) const	\
	{	\
		const ExprP<T0>& Arg0 = args_.a;	\
		const ExprP<T1>& Arg1 = args_.b;	\
		return EXPANSION;	\
	}	\
};	\
DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)

// DEFINE_CASED_DERIVED2 with the SPIR-V case fixed to SPIRV_CASETYPE_NONE.
#define DEFINE_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION)	\
DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRV_CASETYPE_NONE)

// DEFINE_DERIVED2 with all types fixed to double.
#define DEFINE_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION)	\
DEFINE_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION)

// DEFINE_DERIVED2 with all types fixed to float.
#define DEFINE_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION)	\
DEFINE_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION)

// DEFINE_DERIVED2 with all types fixed to deFloat16.
#define DEFINE_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION)	\
DEFINE_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION)

// DEFINE_CASED_DERIVED2 with all types fixed to float.
#define DEFINE_CASED_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE)	\
DEFINE_CASED_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION, SPIRVCASE)

// DEFINE_CASED_DERIVED2 with all types fixed to deFloat16.
#define DEFINE_CASED_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE)	\
DEFINE_CASED_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION, SPIRVCASE)

// DEFINE_CASED_DERIVED2 with all types fixed to double.
#define DEFINE_CASED_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE)	\
DEFINE_CASED_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION, SPIRVCASE)
2970
// Defines a free function NAME(arg0, arg1, arg2) that builds an expression
// applying function class CLASS to three arguments, yielding ExprP<TRET>.
#define DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)	\
ExprP<TRET> NAME (const ExprP<T0>& arg0, const ExprP<T1>& arg1, const ExprP<T2>& arg2)	\
{	\
	return app<CLASS>(arg0, arg1, arg2);	\
}

// Defines a three-argument derived function class CLASS (a DerivedFunc over
// Signature<TRET, T0, T1, T2>) together with its NAME(...) constructor
// function. getName() returns the stringified NAME; doExpand() binds the
// argument expressions to the identifiers ARG0, ARG1 and ARG2 before
// returning EXPANSION.
#define DEFINE_DERIVED3(CLASS, TRET, NAME, T0, ARG0, T1, ARG1, T2, ARG2, EXPANSION)	\
class CLASS : public DerivedFunc<Signature<TRET, T0, T1, T2> > /* NOLINT(CLASS) */	\
{	\
public:	\
	string getName (void) const { return #NAME; }	\
	\
protected:	\
	ExprP<TRET> doExpand (ExpandContext&, const ArgExprs& args_) const	\
	{	\
		const ExprP<T0>& ARG0 = args_.a;	\
		const ExprP<T1>& ARG1 = args_.b;	\
		const ExprP<T2>& ARG2 = args_.c;	\
		return EXPANSION;	\
	}	\
};	\
DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)

// DEFINE_DERIVED3 with all types fixed to double.
#define DEFINE_DERIVED_DOUBLE3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION)	\
DEFINE_DERIVED3(CLASS, double, NAME, double, ARG0, double, ARG1, double, ARG2, EXPANSION)

// DEFINE_DERIVED3 with all types fixed to float.
#define DEFINE_DERIVED_FLOAT3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION)	\
DEFINE_DERIVED3(CLASS, float, NAME, float, ARG0, float, ARG1, float, ARG2, EXPANSION)

// DEFINE_DERIVED3 with all types fixed to deFloat16.
#define DEFINE_DERIVED_FLOAT3_16BIT(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION)	\
DEFINE_DERIVED3(CLASS, deFloat16, NAME, deFloat16, ARG0, deFloat16, ARG1, deFloat16, ARG2, EXPANSION)

// Defines a free function NAME(arg0, arg1, arg2, arg3) that builds an
// expression applying function class CLASS to four arguments.
#define DEFINE_CONSTRUCTOR4(CLASS, TRET, NAME, T0, T1, T2, T3)	\
ExprP<TRET> NAME (const ExprP<T0>& arg0, const ExprP<T1>& arg1,	\
	const ExprP<T2>& arg2, const ExprP<T3>& arg3)	\
{	\
	return app<CLASS>(arg0, arg1, arg2, arg3);	\
}
3009
// Concrete InverseSqrt instantiations for 16-, 32- and 64-bit floating point.
typedef InverseSqrt< Signature<deFloat16, deFloat16> >	InverseSqrt16Bit;
typedef InverseSqrt< Signature<float, float> >			InverseSqrt32Bit;
typedef InverseSqrt< Signature<double, double> >		InverseSqrt64Bit;
3013
// sqrt(x) is expressed as 1 / inversesqrt(x).
DEFINE_DERIVED_FLOAT1(Sqrt32Bit, sqrt, x, constant(1.0f) / app<InverseSqrt32Bit>(x));
DEFINE_DERIVED_FLOAT1_16BIT(Sqrt16Bit, sqrt, x, constant((deFloat16)FLOAT16_1_0) / app<InverseSqrt16Bit>(x));
DEFINE_DERIVED_DOUBLE1(Sqrt64Bit, sqrt, x, constant(1.0) / app<InverseSqrt64Bit>(x));
// pow(x, y) is expressed as exp2(y * log2(x)).
DEFINE_DERIVED_FLOAT2(Pow, pow, x, y, exp2<float>(y * log2(x)));
DEFINE_DERIVED_FLOAT2_16BIT(Pow16, pow, x, y, exp2<deFloat16>(y * log2(x)));
DEFINE_DERIVED_DOUBLE2(Pow64, pow, x, y, exp2<double>(y * log2(x)));
// radians(d) is expressed as (pi / 180) * d.
DEFINE_DERIVED_FLOAT1(Radians, radians, d, (constant(DE_PI) / constant(180.0f)) * d);
DEFINE_DERIVED_FLOAT1_16BIT(Radians16, radians, d, (constant((deFloat16)DE_PI_16BIT) / constant((deFloat16)FLOAT16_180_0)) * d);
DEFINE_DERIVED_DOUBLE1(Radians64, radians, d, (constant((double)(DE_PI)) / constant(180.0)) * d);
// degrees(r) is expressed as (180 / pi) * r.
DEFINE_DERIVED_FLOAT1(Degrees, degrees, r, (constant(180.0f) / constant(DE_PI)) * r);
DEFINE_DERIVED_FLOAT1_16BIT(Degrees16, degrees, r, (constant((deFloat16)FLOAT16_180_0) / constant((deFloat16)DE_PI_16BIT)) * r);
DEFINE_DERIVED_DOUBLE1(Degrees64, degrees, r, (constant(180.0) / constant((double)(DE_PI))) * r);
3026
/*Proper parameters for template T
	Signature<float, float>			32bit tests
	Signature<deFloat16, deFloat16>	16bit tests
	Signature<double, double>		64bit tests*/
3030 template<class T>
3031 class TrigFunc : public CFloatFunc1<T>
3032 {
3033 public:
TrigFunc(const string & name,DoubleFunc1 & func,const Interval & loEx,const Interval & hiEx)3034 TrigFunc (const string& name,
3035 DoubleFunc1& func,
3036 const Interval& loEx,
3037 const Interval& hiEx)
3038 : CFloatFunc1<T> (name, func)
3039 , m_loExtremum (loEx)
3040 , m_hiExtremum (hiEx) {}
3041
3042 protected:
innerExtrema(const EvalContext &,const Interval & angle) const3043 Interval innerExtrema (const EvalContext&, const Interval& angle) const
3044 {
3045 const double lo = angle.lo();
3046 const double hi = angle.hi();
3047 const int loSlope = doGetSlope(lo);
3048 const int hiSlope = doGetSlope(hi);
3049
3050 // Detect the high and low values the function can take between the
3051 // interval endpoints.
3052 if (angle.length() >= 2.0 * DE_PI_DOUBLE)
3053 {
3054 // The interval is longer than a full cycle, so it must get all possible values.
3055 return m_hiExtremum | m_loExtremum;
3056 }
3057 else if (loSlope == 1 && hiSlope == -1)
3058 {
3059 // The slope can change from positive to negative only at the maximum value.
3060 return m_hiExtremum;
3061 }
3062 else if (loSlope == -1 && hiSlope == 1)
3063 {
3064 // The slope can change from negative to positive only at the maximum value.
3065 return m_loExtremum;
3066 }
3067 else if (loSlope == hiSlope &&
3068 deIntSign(CFloatFunc1<T>::applyExact(hi) - CFloatFunc1<T>::applyExact(lo)) * loSlope == -1)
3069 {
3070 // The slope has changed twice between the endpoints, so both extrema are included.
3071 return m_hiExtremum | m_loExtremum;
3072 }
3073
3074 return Interval();
3075 }
3076
getCodomain(const EvalContext &) const3077 Interval getCodomain (const EvalContext&) const
3078 {
3079 // Ensure that result is always within [-1, 1], or NaN (for +-inf)
3080 return Interval(-1.0, 1.0) | TCU_NAN;
3081 }
3082
3083 double precision (const EvalContext& ctx, double ret, double arg) const;
3084
3085 Interval getInputRange (const bool is16bit) const;
3086 virtual int doGetSlope (double angle) const = 0;
3087
3088 Interval m_loExtremum;
3089 Interval m_hiExtremum;
3090 };
3091
3092 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3093 template<>
getInputRange(const bool is16bit) const3094 Interval TrigFunc<Signature<float, float> >::getInputRange(const bool is16bit) const
3095 {
3096 DE_UNREF(is16bit);
3097 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3098 }
3099
3100 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3101 template<>
getInputRange(const bool is16bit) const3102 Interval TrigFunc<Signature<deFloat16, deFloat16> >::getInputRange(const bool is16bit) const
3103 {
3104 DE_UNREF(is16bit);
3105 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3106 }
3107
3108 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3109 template<>
getInputRange(const bool is16bit) const3110 Interval TrigFunc<Signature<double, double> >::getInputRange(const bool is16bit) const
3111 {
3112 DE_UNREF(is16bit);
3113 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3114 }
3115
3116 template<>
precision(const EvalContext & ctx,double ret,double arg) const3117 double TrigFunc<Signature<float, float> >::precision(const EvalContext& ctx, double ret, double arg) const
3118 {
3119 DE_UNREF(ret);
3120 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3121 {
3122 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3123 return deLdExp(1.0, -11);
3124 else
3125 {
3126 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3127 // 2^-11 at x == pi.
3128 return deLdExp(deAbs(arg), -12);
3129 }
3130 }
3131 else
3132 {
3133 DE_ASSERT(ctx.floatPrecision == glu::PRECISION_MEDIUMP || ctx.floatPrecision == glu::PRECISION_LAST);
3134
3135 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3136 return deLdExp(1.0, -7);
3137 else
3138 {
3139 // |x| * 2^-8, slightly larger than 2^-7 at x == pi
3140 return deLdExp(deAbs(arg), -8);
3141 }
3142 }
3143 }
3144 //
3145 /*
3146 * Half tests
3147 * From Spec:
3148 * Absolute error 2^{-7} inside the range [-pi, pi].
3149 */
3150 template<>
precision(const EvalContext & ctx,double ret,double arg) const3151 double TrigFunc<Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double arg) const
3152 {
3153 DE_UNREF(ctx);
3154 DE_UNREF(ret);
3155 DE_UNREF(arg);
3156 DE_ASSERT(-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE && ctx.floatPrecision == glu::PRECISION_LAST);
3157 return deLdExp(1.0, -7);
3158 }
3159
3160 // Spec: "The precision of double-precision instructions is at least that of single precision."
3161 // Lets pick float high precision as a reference.
3162 template<>
precision(const EvalContext & ctx,double ret,double arg) const3163 double TrigFunc<Signature<double, double> >::precision(const EvalContext& ctx, double ret, double arg) const
3164 {
3165 DE_UNREF(ctx);
3166 DE_UNREF(ret);
3167 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3168 return deLdExp(1.0, -11);
3169 else
3170 {
3171 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3172 // 2^-11 at x == pi.
3173 return deLdExp(deAbs(arg), -12);
3174 }
3175 }
3176
/*Proper parameters for template T
	Signature<float, float>			32bit tests
	Signature<deFloat16, deFloat16>	16bit tests
	Signature<double, double>		64bit tests*/
3180 template <class T>
3181 class Sin : public TrigFunc<T>
3182 {
3183 public:
Sin(void)3184 Sin (void) : TrigFunc<T>("sin", deSin, -1.0, 1.0) {}
3185
3186 protected:
doGetSlope(double angle) const3187 int doGetSlope (double angle) const { return deIntSign(deCos(angle)); }
3188 };
3189
// Expression builders for sin(x), one overload per supported element type.
ExprP<float> sin (const ExprP<float>& x)
{
	typedef Sin<Signature<float, float> > Func;
	return app<Func>(x);
}

ExprP<deFloat16> sin (const ExprP<deFloat16>& x)
{
	typedef Sin<Signature<deFloat16, deFloat16> > Func;
	return app<Func>(x);
}

ExprP<double> sin (const ExprP<double>& x)
{
	typedef Sin<Signature<double, double> > Func;
	return app<Func>(x);
}
3193
3194 template <class T>
3195 class Cos : public TrigFunc<T>
3196 {
3197 public:
Cos(void)3198 Cos (void) : TrigFunc<T> ("cos", deCos, -1.0, 1.0) {}
3199
3200 protected:
doGetSlope(double angle) const3201 int doGetSlope (double angle) const { return -deIntSign(deSin(angle)); }
3202 };
3203
// Expression builders for cos(x), one overload per supported element type.
ExprP<float> cos (const ExprP<float>& x)
{
	typedef Cos<Signature<float, float> > Func;
	return app<Func>(x);
}

ExprP<deFloat16> cos (const ExprP<deFloat16>& x)
{
	typedef Cos<Signature<deFloat16, deFloat16> > Func;
	return app<Func>(x);
}

ExprP<double> cos (const ExprP<double>& x)
{
	typedef Cos<Signature<double, double> > Func;
	return app<Func>(x);
}
3207
// tan(x) is expressed as sin(x) * (1 / cos(x)); each variant reports an input
// range built from -DE_PI_DOUBLE and DE_PI_DOUBLE.
DEFINE_DERIVED_FLOAT1_INPUTRANGE(Tan, tan, x, sin(x) * (constant(1.0f) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE));
DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(Tan16Bit, tan, x, sin(x) * (constant((deFloat16)FLOAT16_1_0) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE));
DEFINE_DERIVED_DOUBLE1_INPUTRANGE(Tan64Bit, tan, x, sin(x) * (constant(1.0) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE));
3211
3212 template <class T>
3213 class ATan : public CFloatFunc1<T>
3214 {
3215 public:
ATan(void)3216 ATan (void) : CFloatFunc1<T> ("atan", deAtanOver) {}
3217
3218 protected:
precision(const EvalContext & ctx,double ret,double) const3219 double precision (const EvalContext& ctx, double ret, double) const
3220 {
3221 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3222 return ctx.format.ulp(ret, 4096.0);
3223 else
3224 return ctx.format.ulp(ret, 5.0);
3225 }
3226
getCodomain(const EvalContext & ctx) const3227 Interval getCodomain(const EvalContext& ctx) const
3228 {
3229 return ctx.format.roundOut(Interval(-0.5 * DE_PI_DOUBLE, 0.5 * DE_PI_DOUBLE), true);
3230 }
3231 };
3232
3233 template <class T>
3234 class ATan2 : public CFloatFunc2<T>
3235 {
3236 public:
ATan2(void)3237 ATan2 (void) : CFloatFunc2<T> ("atan", deAtan2) {}
3238
3239 protected:
innerExtrema(const EvalContext & ctx,const Interval & yi,const Interval & xi) const3240 Interval innerExtrema (const EvalContext& ctx,
3241 const Interval& yi,
3242 const Interval& xi) const
3243 {
3244 Interval ret;
3245
3246 if (yi.contains(0.0))
3247 {
3248 if (xi.contains(0.0))
3249 ret |= TCU_NAN;
3250 if (xi.intersects(Interval(-TCU_INFINITY, 0.0)))
3251 ret |= ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3252 }
3253
3254 if (ctx.format.hasInf() != YES && (!yi.isFinite() || !xi.isFinite()))
3255 {
3256 // Infinities may not be supported, allow anything, including NaN
3257 ret |= TCU_NAN;
3258 }
3259
3260 return ret;
3261 }
3262
precision(const EvalContext & ctx,double ret,double,double) const3263 double precision (const EvalContext& ctx, double ret, double, double) const
3264 {
3265 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3266 return ctx.format.ulp(ret, 4096.0);
3267 else
3268 return ctx.format.ulp(ret, 5.0);
3269 }
3270
getCodomain(const EvalContext & ctx) const3271 Interval getCodomain(const EvalContext& ctx) const
3272 {
3273 return ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3274 }
3275 };
3276
// Expression builders for the two-argument arctangent, one overload per
// supported element type. The first argument is the one ATan2::innerExtrema
// receives as 'yi', the second the one it receives as 'xi'.
ExprP<float> atan2 (const ExprP<float>& y, const ExprP<float>& x)
{
	typedef ATan2<Signature<float, float, float> > Func;
	return app<Func>(y, x);
}

ExprP<deFloat16> atan2 (const ExprP<deFloat16>& y, const ExprP<deFloat16>& x)
{
	typedef ATan2<Signature<deFloat16, deFloat16, deFloat16> > Func;
	return app<Func>(y, x);
}

ExprP<double> atan2 (const ExprP<double>& y, const ExprP<double>& x)
{
	typedef ATan2<Signature<double, double, double> > Func;
	return app<Func>(y, x);
}
3282
3283
// Hyperbolic functions for 32-bit floats:
// sinh(x) = (exp(x) - exp(-x)) / 2, cosh(x) = (exp(x) + exp(-x)) / 2,
// tanh(x) = sinh(x) / cosh(x).
DEFINE_DERIVED_FLOAT1(Sinh, sinh, x, (exp<float>(x) - exp<float>(-x)) / constant(2.0f));
DEFINE_DERIVED_FLOAT1(Cosh, cosh, x, (exp<float>(x) + exp<float>(-x)) / constant(2.0f));
DEFINE_DERIVED_FLOAT1(Tanh, tanh, x, sinh(x) / cosh(x));

// The same expansions for 16-bit floats.
DEFINE_DERIVED_FLOAT1_16BIT(Sinh16Bit, sinh, x, (exp(x) - exp(-x)) / constant((deFloat16)FLOAT16_2_0));
DEFINE_DERIVED_FLOAT1_16BIT(Cosh16Bit, cosh, x, (exp(x) + exp(-x)) / constant((deFloat16)FLOAT16_2_0));
DEFINE_DERIVED_FLOAT1_16BIT(Tanh16Bit, tanh, x, sinh(x) / cosh(x));

// The same expansions for 64-bit floats.
DEFINE_DERIVED_DOUBLE1(Sinh64Bit, sinh, x, (exp<double>(x) - exp<double>(-x)) / constant(2.0));
DEFINE_DERIVED_DOUBLE1(Cosh64Bit, cosh, x, (exp<double>(x) + exp<double>(-x)) / constant(2.0));
DEFINE_DERIVED_DOUBLE1(Tanh64Bit, tanh, x, sinh(x) / cosh(x));
3295
// Inverse trigonometric and hyperbolic functions for 32-bit floats:
// asin(x)  = atan2(x, sqrt(1 - x*x))
// acos(x)  = atan2(sqrt(1 - x*x), x)
// asinh(x) = log(x + sqrt(x*x + 1))
// acosh(x) = log(x + sqrt(x*x - 1)), where the radicand is given in two
//            alternative forms: (x + 1) * (x - 1) and (x*x - 1)
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
DEFINE_DERIVED_FLOAT1(ASin, asin, x, atan2(x, sqrt(constant(1.0f) - x * x)));
DEFINE_DERIVED_FLOAT1(ACos, acos, x, atan2(sqrt(constant(1.0f) - x * x), x));
DEFINE_DERIVED_FLOAT1(ASinh, asinh, x, log(x + sqrt(x * x + constant(1.0f))));
DEFINE_DERIVED_FLOAT1(ACosh, acosh, x, log(x + sqrt(alternatives((x + constant(1.0f)) * (x - constant(1.0f)),
																 (x * x - constant(1.0f))))));
DEFINE_DERIVED_FLOAT1(ATanh, atanh, x, constant(0.5f) * log((constant(1.0f) + x) /
															(constant(1.0f) - x)));

// The same expansions for 16-bit floats.
DEFINE_DERIVED_FLOAT1_16BIT(ASin16Bit, asin, x, atan2(x, sqrt(constant((deFloat16)FLOAT16_1_0) - x * x)));
DEFINE_DERIVED_FLOAT1_16BIT(ACos16Bit, acos, x, atan2(sqrt(constant((deFloat16)FLOAT16_1_0) - x * x), x));
DEFINE_DERIVED_FLOAT1_16BIT(ASinh16Bit, asinh, x, log(x + sqrt(x * x + constant((deFloat16)FLOAT16_1_0))));
DEFINE_DERIVED_FLOAT1_16BIT(ACosh16Bit, acosh, x, log(x + sqrt(alternatives((x + constant((deFloat16)FLOAT16_1_0)) * (x - constant((deFloat16)FLOAT16_1_0)),
																 (x * x - constant((deFloat16)FLOAT16_1_0))))));
DEFINE_DERIVED_FLOAT1_16BIT(ATanh16Bit, atanh, x, constant((deFloat16)FLOAT16_0_5) * log((constant((deFloat16)FLOAT16_1_0) + x) /
															(constant((deFloat16)FLOAT16_1_0) - x)));

// 64-bit variants; asin and acos here square x through pow(x, 2.0) rather
// than x * x.
DEFINE_DERIVED_DOUBLE1(ASin64Bit, asin, x, atan2(x, sqrt(constant(1.0) - pow(x, constant(2.0)))));
DEFINE_DERIVED_DOUBLE1(ACos64Bit, acos, x, atan2(sqrt(constant(1.0) - pow(x, constant(2.0))), x));
DEFINE_DERIVED_DOUBLE1(ASinh64Bit, asinh, x, log(x + sqrt(x * x + constant(1.0))));
DEFINE_DERIVED_DOUBLE1(ACosh64Bit, acosh, x, log(x + sqrt(alternatives((x + constant(1.0)) * (x - constant(1.0)),
																 (x * x - constant(1.0))))));
DEFINE_DERIVED_DOUBLE1(ATanh64Bit, atanh, x, constant(0.5) * log((constant(1.0) + x) /
															(constant(1.0) - x)));
3319
3320 template <typename T>
3321 class GetComponent : public PrimitiveFunc<Signature<typename T::Element, T, int> >
3322 {
3323 public:
3324 typedef typename GetComponent::IRet IRet;
3325
getName(void) const3326 string getName (void) const { return "_getComponent"; }
3327
print(ostream & os,const BaseArgExprs & args) const3328 void print (ostream& os,
3329 const BaseArgExprs& args) const
3330 {
3331 os << *args[0] << "[" << *args[1] << "]";
3332 }
3333
3334 protected:
doApply(const EvalContext &,const typename GetComponent::IArgs & iargs) const3335 IRet doApply (const EvalContext&,
3336 const typename GetComponent::IArgs& iargs) const
3337 {
3338 IRet ret;
3339
3340 for (int compNdx = 0; compNdx < T::SIZE; ++compNdx)
3341 {
3342 if (iargs.b.contains(compNdx))
3343 ret = unionIVal<typename T::Element>(ret, iargs.a[compNdx]);
3344 }
3345
3346 return ret;
3347 }
3348
3349 };
3350
3351 template <typename T>
getComponent(const ExprP<T> & container,int ndx)3352 ExprP<typename T::Element> getComponent (const ExprP<T>& container, int ndx)
3353 {
3354 DE_ASSERT(0 <= ndx && ndx < T::SIZE);
3355 return app<GetComponent<T> >(container, constant(ndx));
3356 }
3357
3358 template <typename T> string vecNamePrefix (void);
vecNamePrefix(void)3359 template <> string vecNamePrefix<float> (void) { return ""; }
vecNamePrefix(void)3360 template <> string vecNamePrefix<deFloat16>(void) { return ""; }
vecNamePrefix(void)3361 template <> string vecNamePrefix<double> (void) { return "d"; }
vecNamePrefix(void)3362 template <> string vecNamePrefix<int> (void) { return "i"; }
vecNamePrefix(void)3363 template <> string vecNamePrefix<bool> (void) { return "b"; }
3364
3365 template <typename T, int Size>
vecName(void)3366 string vecName (void) { return vecNamePrefix<T>() + "vec" + de::toString(Size); }
3367
3368 template <typename T, int Size> class GenVec;
3369
3370 template <typename T>
3371 class GenVec<T, 1> : public DerivedFunc<Signature<T, T> >
3372 {
3373 public:
3374 typedef typename GenVec<T, 1>::ArgExprs ArgExprs;
3375
getName(void) const3376 string getName (void) const
3377 {
3378 return "_" + vecName<T, 1>();
3379 }
3380
3381 protected:
3382
doExpand(ExpandContext &,const ArgExprs & args) const3383 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const { return args.a; }
3384 };
3385
3386 template <typename T>
3387 class GenVec<T, 2> : public PrimitiveFunc<Signature<Vector<T, 2>, T, T> >
3388 {
3389 public:
3390 typedef typename GenVec::IRet IRet;
3391 typedef typename GenVec::IArgs IArgs;
3392
getName(void) const3393 string getName (void) const
3394 {
3395 return vecName<T, 2>();
3396 }
3397
3398 protected:
doApply(const EvalContext &,const IArgs & iargs) const3399 IRet doApply (const EvalContext&, const IArgs& iargs) const
3400 {
3401 return IRet(iargs.a, iargs.b);
3402 }
3403 };
3404
3405 template <typename T>
3406 class GenVec<T, 3> : public PrimitiveFunc<Signature<Vector<T, 3>, T, T, T> >
3407 {
3408 public:
3409 typedef typename GenVec::IRet IRet;
3410 typedef typename GenVec::IArgs IArgs;
3411
getName(void) const3412 string getName (void) const
3413 {
3414 return vecName<T, 3>();
3415 }
3416
3417 protected:
doApply(const EvalContext &,const IArgs & iargs) const3418 IRet doApply (const EvalContext&, const IArgs& iargs) const
3419 {
3420 return IRet(iargs.a, iargs.b, iargs.c);
3421 }
3422 };
3423
3424 template <typename T>
3425 class GenVec<T, 4> : public PrimitiveFunc<Signature<Vector<T, 4>, T, T, T, T> >
3426 {
3427 public:
3428 typedef typename GenVec::IRet IRet;
3429 typedef typename GenVec::IArgs IArgs;
3430
getName(void) const3431 string getName (void) const { return vecName<T, 4>(); }
3432
3433 protected:
doApply(const EvalContext &,const IArgs & iargs) const3434 IRet doApply (const EvalContext&, const IArgs& iargs) const
3435 {
3436 return IRet(iargs.a, iargs.b, iargs.c, iargs.d);
3437 }
3438 };
3439
3440 template <typename T, int Rows, int Columns>
3441 class GenMat;
3442
3443 template <typename T, int Rows>
3444 class GenMat<T, Rows, 2> : public PrimitiveFunc<
3445 Signature<Matrix<T, Rows, 2>, Vector<T, Rows>, Vector<T, Rows> > >
3446 {
3447 public:
3448 typedef typename GenMat::Ret Ret;
3449 typedef typename GenMat::IRet IRet;
3450 typedef typename GenMat::IArgs IArgs;
3451
getName(void) const3452 string getName (void) const
3453 {
3454 return dataTypeNameOf<Ret>();
3455 }
3456
3457 protected:
3458
doApply(const EvalContext &,const IArgs & iargs) const3459 IRet doApply (const EvalContext&, const IArgs& iargs) const
3460 {
3461 IRet ret;
3462 ret[0] = iargs.a;
3463 ret[1] = iargs.b;
3464 return ret;
3465 }
3466 };
3467
3468 template <typename T, int Rows>
3469 class GenMat<T, Rows, 3> : public PrimitiveFunc<
3470 Signature<Matrix<T, Rows, 3>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows> > >
3471 {
3472 public:
3473 typedef typename GenMat::Ret Ret;
3474 typedef typename GenMat::IRet IRet;
3475 typedef typename GenMat::IArgs IArgs;
3476
getName(void) const3477 string getName (void) const
3478 {
3479 return dataTypeNameOf<Ret>();
3480 }
3481
3482 protected:
3483
doApply(const EvalContext &,const IArgs & iargs) const3484 IRet doApply (const EvalContext&, const IArgs& iargs) const
3485 {
3486 IRet ret;
3487 ret[0] = iargs.a;
3488 ret[1] = iargs.b;
3489 ret[2] = iargs.c;
3490 return ret;
3491 }
3492 };
3493
3494 template <typename T, int Rows>
3495 class GenMat<T, Rows, 4> : public PrimitiveFunc<
3496 Signature<Matrix<T, Rows, 4>,
3497 Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows> > >
3498 {
3499 public:
3500 typedef typename GenMat::Ret Ret;
3501 typedef typename GenMat::IRet IRet;
3502 typedef typename GenMat::IArgs IArgs;
3503
getName(void) const3504 string getName (void) const
3505 {
3506 return dataTypeNameOf<Ret>();
3507 }
3508
3509 protected:
doApply(const EvalContext &,const IArgs & iargs) const3510 IRet doApply (const EvalContext&, const IArgs& iargs) const
3511 {
3512 IRet ret;
3513 ret[0] = iargs.a;
3514 ret[1] = iargs.b;
3515 ret[2] = iargs.c;
3516 ret[3] = iargs.d;
3517 return ret;
3518 }
3519 };
3520
3521 template <typename T, int Rows>
mat2(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1)3522 ExprP<Matrix<T, Rows, 2> > mat2 (const ExprP<Vector<T, Rows> >& arg0,
3523 const ExprP<Vector<T, Rows> >& arg1)
3524 {
3525 return app<GenMat<T, Rows, 2> >(arg0, arg1);
3526 }
3527
3528 template <typename T, int Rows>
mat3(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2)3529 ExprP<Matrix<T, Rows, 3> > mat3 (const ExprP<Vector<T, Rows> >& arg0,
3530 const ExprP<Vector<T, Rows> >& arg1,
3531 const ExprP<Vector<T, Rows> >& arg2)
3532 {
3533 return app<GenMat<T, Rows, 3> >(arg0, arg1, arg2);
3534 }
3535
3536 template <typename T, int Rows>
mat4(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2,const ExprP<Vector<T,Rows>> & arg3)3537 ExprP<Matrix<T, Rows, 4> > mat4 (const ExprP<Vector<T, Rows> >& arg0,
3538 const ExprP<Vector<T, Rows> >& arg1,
3539 const ExprP<Vector<T, Rows> >& arg2,
3540 const ExprP<Vector<T, Rows> >& arg3)
3541 {
3542 return app<GenMat<T, Rows, 4> >(arg0, arg1, arg2, arg3);
3543 }
3544
3545 template <typename T, int Rows, int Cols>
3546 class MatNeg : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>,
3547 Matrix<T, Rows, Cols> > >
3548 {
3549 public:
3550 typedef typename MatNeg::IRet IRet;
3551 typedef typename MatNeg::IArgs IArgs;
3552
getName(void) const3553 string getName (void) const
3554 {
3555 return "_matNeg";
3556 }
3557
3558 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3559 void doPrint (ostream& os, const BaseArgExprs& args) const
3560 {
3561 os << "-(" << *args[0] << ")";
3562 }
3563
doApply(const EvalContext &,const IArgs & iargs) const3564 IRet doApply (const EvalContext&, const IArgs& iargs) const
3565 {
3566 IRet ret;
3567
3568 for (int col = 0; col < Cols; ++col)
3569 {
3570 for (int row = 0; row < Rows; ++row)
3571 ret[col][row] = -iargs.a[col][row];
3572 }
3573
3574 return ret;
3575 }
3576 };
3577
3578 template <typename T, typename Sig>
3579 class CompWiseFunc : public PrimitiveFunc<Sig>
3580 {
3581 public:
3582 typedef Func<Signature<T, T, T> > ScalarFunc;
3583
getName(void) const3584 string getName (void) const
3585 {
3586 return doGetScalarFunc().getName();
3587 }
3588 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3589 void doPrint (ostream& os,
3590 const BaseArgExprs& args) const
3591 {
3592 doGetScalarFunc().print(os, args);
3593 }
3594
3595 virtual
3596 const ScalarFunc& doGetScalarFunc (void) const = 0;
3597 };
3598
3599 template <typename T, int Rows, int Cols>
3600 class CompMatFuncBase : public CompWiseFunc<T, Signature<Matrix<T, Rows, Cols>,
3601 Matrix<T, Rows, Cols>,
3602 Matrix<T, Rows, Cols> > >
3603 {
3604 public:
3605 typedef typename CompMatFuncBase::IRet IRet;
3606 typedef typename CompMatFuncBase::IArgs IArgs;
3607
3608 protected:
3609
doApply(const EvalContext & ctx,const IArgs & iargs) const3610 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
3611 {
3612 IRet ret;
3613
3614 for (int col = 0; col < Cols; ++col)
3615 {
3616 for (int row = 0; row < Rows; ++row)
3617 ret[col][row] = this->doGetScalarFunc().apply(ctx,
3618 iargs.a[col][row],
3619 iargs.b[col][row]);
3620 }
3621
3622 return ret;
3623 }
3624 };
3625
3626 template <typename F, typename T, int Rows, int Cols>
3627 class CompMatFunc : public CompMatFuncBase<T, Rows, Cols>
3628 {
3629 protected:
doGetScalarFunc(void) const3630 const typename CompMatFunc::ScalarFunc& doGetScalarFunc (void) const
3631 {
3632 return instance<F>();
3633 }
3634 };
3635
3636 template <class T>
3637 class ScalarMatrixCompMult : public Mul< Signature<T, T, T> >
3638 {
3639 public:
3640
getName(void) const3641 string getName (void) const
3642 {
3643 return "matrixCompMult";
3644 }
3645
doPrint(ostream & os,const BaseArgExprs & args) const3646 void doPrint (ostream& os, const BaseArgExprs& args) const
3647 {
3648 Func<Signature<T, T, T> >::doPrint(os, args);
3649 }
3650 };
3651
//! Component-wise matrix function that uses ScalarMatrixCompMult<T> as its
//! per-component scalar function; adds nothing beyond what CompMatFunc provides.
template <int Rows, int Cols, class T>
class MatrixCompMult : public CompMatFunc<ScalarMatrixCompMult<T>, T, Rows, Cols>
{
};
3656
3657 template <int Rows, int Cols>
3658 class ScalarMatFuncBase : public CompWiseFunc<float, Signature<Matrix<float, Rows, Cols>,
3659 Matrix<float, Rows, Cols>,
3660 float> >
3661 {
3662 public:
3663 typedef typename ScalarMatFuncBase::IRet IRet;
3664 typedef typename ScalarMatFuncBase::IArgs IArgs;
3665
3666 protected:
3667
doApply(const EvalContext & ctx,const IArgs & iargs) const3668 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
3669 {
3670 IRet ret;
3671
3672 for (int col = 0; col < Cols; ++col)
3673 {
3674 for (int row = 0; row < Rows; ++row)
3675 ret[col][row] = this->doGetScalarFunc().apply(ctx, iargs.a[col][row], iargs.b);
3676 }
3677
3678 return ret;
3679 }
3680 };
3681
3682 template <typename F, int Rows, int Cols>
3683 class ScalarMatFunc : public ScalarMatFuncBase<Rows, Cols>
3684 {
3685 protected:
doGetScalarFunc(void) const3686 const typename ScalarMatFunc::ScalarFunc& doGetScalarFunc (void) const
3687 {
3688 return instance<F>();
3689 }
3690 };
3691
3692 template<typename T, int Size> struct GenXType;
3693
3694 template<typename T>
3695 struct GenXType<T, 1>
3696 {
genXTypevkt::shaderexecutor::Functions::GenXType3697 static ExprP<T> genXType (const ExprP<T>& x) { return x; }
3698 };
3699
3700 template<typename T>
3701 struct GenXType<T, 2>
3702 {
genXTypevkt::shaderexecutor::Functions::GenXType3703 static ExprP<Vector<T, 2> > genXType (const ExprP<T>& x)
3704 {
3705 return app<GenVec<T, 2> >(x, x);
3706 }
3707 };
3708
3709 template<typename T>
3710 struct GenXType<T, 3>
3711 {
genXTypevkt::shaderexecutor::Functions::GenXType3712 static ExprP<Vector<T, 3> > genXType (const ExprP<T>& x)
3713 {
3714 return app<GenVec<T, 3> >(x, x, x);
3715 }
3716 };
3717
3718 template<typename T>
3719 struct GenXType<T, 4>
3720 {
genXTypevkt::shaderexecutor::Functions::GenXType3721 static ExprP<Vector<T, 4> > genXType (const ExprP<T>& x)
3722 {
3723 return app<GenVec<T, 4> >(x, x, x, x);
3724 }
3725 };
3726
3727 //! Returns an expression of vector of size `Size` (or scalar if Size == 1),
3728 //! with each element initialized with the expression `x`.
3729 template<typename T, int Size>
genXType(const ExprP<T> & x)3730 ExprP<typename ContainerOf<T, Size>::Container> genXType (const ExprP<T>& x)
3731 {
3732 return GenXType<T, Size>::genXType(x);
3733 }
3734
// Vector constructor helpers for float, deFloat16 and double, sizes 2 to 4.
// Each typedef names a GenVec instantiation; the DEFINE_CONSTRUCTORn line after it
// hands the macro that typedef, the vector type, the helper name (vec2/vec3/vec4)
// and the per-component scalar types.
// NOTE(review): the DEFINE_CONSTRUCTORn macros are defined outside this section;
// presumably they emit the vecN(...) expression builders -- confirm at their definition.
typedef GenVec<float, 2> FloatVec2;
DEFINE_CONSTRUCTOR2(FloatVec2, Vec2, vec2, float, float)

typedef GenVec<deFloat16, 2> FloatVec2_16bit;
DEFINE_CONSTRUCTOR2(FloatVec2_16bit, Vec2_16Bit, vec2, deFloat16, deFloat16)

typedef GenVec<double, 2> DoubleVec2;
DEFINE_CONSTRUCTOR2(DoubleVec2, Vec2_64Bit, vec2, double, double)

typedef GenVec<float, 3> FloatVec3;
DEFINE_CONSTRUCTOR3(FloatVec3, Vec3, vec3, float, float, float)

typedef GenVec<deFloat16, 3> FloatVec3_16bit;
DEFINE_CONSTRUCTOR3(FloatVec3_16bit, Vec3_16Bit, vec3, deFloat16, deFloat16, deFloat16)

typedef GenVec<double, 3> DoubleVec3;
DEFINE_CONSTRUCTOR3(DoubleVec3, Vec3_64Bit, vec3, double, double, double)

typedef GenVec<float, 4> FloatVec4;
DEFINE_CONSTRUCTOR4(FloatVec4, Vec4, vec4, float, float, float, float)

typedef GenVec<deFloat16, 4> FloatVec4_16bit;
DEFINE_CONSTRUCTOR4(FloatVec4_16bit, Vec4_16Bit, vec4, deFloat16, deFloat16, deFloat16, deFloat16)

typedef GenVec<double, 4> DoubleVec4;
DEFINE_CONSTRUCTOR4(DoubleVec4, Vec4_64Bit, vec4, double, double, double, double)
3761
// Constant-expression factories for the values 0, 1 and 2 in scalar type T.
// Only declared here; explicit specializations exist for float, deFloat16 and double.
template <class T>
const ExprP<T> getConstZero(void);
template <class T>
const ExprP<T> getConstOne(void);
template <class T>
const ExprP<T> getConstTwo(void);
3768
3769 template <>
getConstZero(void)3770 const ExprP<float> getConstZero<float>(void)
3771 {
3772 return constant(0.0f);
3773 }
3774
3775 template <>
getConstZero(void)3776 const ExprP<deFloat16> getConstZero<deFloat16>(void)
3777 {
3778 return constant((deFloat16)FLOAT16_0_0);
3779 }
3780
3781 template <>
getConstZero(void)3782 const ExprP<double> getConstZero<double>(void)
3783 {
3784 return constant(0.0);
3785 }
3786
3787 template <>
getConstOne(void)3788 const ExprP<float> getConstOne<float>(void)
3789 {
3790 return constant(1.0f);
3791 }
3792
3793 template <>
getConstOne(void)3794 const ExprP<deFloat16> getConstOne<deFloat16>(void)
3795 {
3796 return constant((deFloat16)FLOAT16_1_0);
3797 }
3798
3799 template <>
getConstOne(void)3800 const ExprP<double> getConstOne<double>(void)
3801 {
3802 return constant(1.0);
3803 }
3804
3805 template <>
getConstTwo(void)3806 const ExprP<float> getConstTwo<float>(void)
3807 {
3808 return constant(2.0f);
3809 }
3810
3811 template <>
getConstTwo(void)3812 const ExprP<deFloat16> getConstTwo<deFloat16>(void)
3813 {
3814 return constant((deFloat16)FLOAT16_2_0);
3815 }
3816
3817 template <>
getConstTwo(void)3818 const ExprP<double> getConstTwo<double>(void)
3819 {
3820 return constant(2.0);
3821 }
3822
3823 template <int Size, class T>
3824 class Dot : public DerivedFunc<Signature<T, Vector<T, Size>, Vector<T, Size> > >
3825 {
3826 public:
3827 typedef typename Dot::ArgExprs ArgExprs;
3828
getName(void) const3829 string getName (void) const
3830 {
3831 return "dot";
3832 }
3833
3834 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3835 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const
3836 {
3837 ExprP<T> op[Size];
3838 // Precompute all products.
3839 for (int ndx = 0; ndx < Size; ++ndx)
3840 op[ndx] = args.a[ndx] * args.b[ndx];
3841
3842 int idx[Size];
3843 //Prepare an array of indices.
3844 for (int ndx = 0; ndx < Size; ++ndx)
3845 idx[ndx] = ndx;
3846
3847 ExprP<T> res = op[0];
3848 // Compute the first dot alternative: SUM(a[i]*b[i]), i = 0 .. Size-1
3849 for (int ndx = 1; ndx < Size; ++ndx)
3850 res = res + op[ndx];
3851
3852 // Generate all permutations of indices and
3853 // using a permutation compute a dot alternative.
3854 // Generates all possible variants fo summation of products in the dot product expansion expression.
3855 do {
3856 ExprP<T> alt = getConstZero<T>();
3857 for (int ndx = 0; ndx < Size; ++ndx)
3858 alt = alt + op[idx[ndx]];
3859 res = alternatives(res, alt);
3860 } while (std::next_permutation(idx, idx + Size));
3861
3862 return res;
3863 }
3864 };
3865
3866 template <class T>
3867 class Dot<1, T> : public DerivedFunc<Signature<T, T, T> >
3868 {
3869 public:
3870 typedef typename DerivedFunc<Signature<T, T, T> >::ArgExprs TArgExprs;
3871
getName(void) const3872 string getName (void) const
3873 {
3874 return "dot";
3875 }
3876
doExpand(ExpandContext &,const TArgExprs & args) const3877 ExprP<T> doExpand (ExpandContext&, const TArgExprs& args) const
3878 {
3879 return args.a * args.b;
3880 }
3881 };
3882
3883 template <int Size>
dot(const ExprP<Vector<deFloat16,Size>> & x,const ExprP<Vector<deFloat16,Size>> & y)3884 ExprP<deFloat16> dot (const ExprP<Vector<deFloat16, Size> >& x, const ExprP<Vector<deFloat16, Size> >& y)
3885 {
3886 return app<Dot<Size, deFloat16> >(x, y);
3887 }
3888
dot(const ExprP<deFloat16> & x,const ExprP<deFloat16> & y)3889 ExprP<deFloat16> dot (const ExprP<deFloat16>& x, const ExprP<deFloat16>& y)
3890 {
3891 return app<Dot<1, deFloat16> >(x, y);
3892 }
3893
3894 template <int Size>
dot(const ExprP<Vector<float,Size>> & x,const ExprP<Vector<float,Size>> & y)3895 ExprP<float> dot (const ExprP<Vector<float, Size> >& x, const ExprP<Vector<float, Size> >& y)
3896 {
3897 return app<Dot<Size, float> >(x, y);
3898 }
3899
dot(const ExprP<float> & x,const ExprP<float> & y)3900 ExprP<float> dot (const ExprP<float>& x, const ExprP<float>& y)
3901 {
3902 return app<Dot<1, float> >(x, y);
3903 }
3904
3905 template <int Size>
dot(const ExprP<Vector<double,Size>> & x,const ExprP<Vector<double,Size>> & y)3906 ExprP<double> dot (const ExprP<Vector<double, Size> >& x, const ExprP<Vector<double, Size> >& y)
3907 {
3908 return app<Dot<Size, double> >(x, y);
3909 }
3910
dot(const ExprP<double> & x,const ExprP<double> & y)3911 ExprP<double> dot (const ExprP<double>& x, const ExprP<double>& y)
3912 {
3913 return app<Dot<1, double> >(x, y);
3914 }
3915
3916 template <int Size, class T>
3917 class Length : public DerivedFunc<
3918 Signature<T, typename ContainerOf<T, Size>::Container> >
3919 {
3920 public:
3921 typedef typename Length::ArgExprs ArgExprs;
3922
getName(void) const3923 string getName (void) const
3924 {
3925 return "length";
3926 }
3927
3928 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3929 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const
3930 {
3931 return sqrt(dot(args.a, args.a));
3932 }
3933 };
3934
3935
3936 template <class T, class TRet>
length(const ExprP<T> & x)3937 ExprP<TRet> length (const ExprP<T>& x)
3938 {
3939 return app<Length<1, T> >(x);
3940 }
3941
3942 template <int Size, class T, class TRet>
length(const ExprP<typename ContainerOf<T,Size>::Container> & x)3943 ExprP<TRet> length (const ExprP<typename ContainerOf<T, Size>::Container>& x)
3944 {
3945 return app<Length<Size, T> >(x);
3946 }
3947
3948 template <int Size, class T>
3949 class Distance : public DerivedFunc<
3950 Signature<T,
3951 typename ContainerOf<T, Size>::Container,
3952 typename ContainerOf<T, Size>::Container> >
3953 {
3954 public:
3955 typedef typename Distance::Ret Ret;
3956 typedef typename Distance::ArgExprs ArgExprs;
3957
getName(void) const3958 string getName (void) const
3959 {
3960 return "distance";
3961 }
3962
3963 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3964 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
3965 {
3966 return length<Size, T, Ret>(args.a - args.b);
3967 }
3968 };
3969
3970 // cross
3971
3972 class Cross : public DerivedFunc<Signature<Vec3, Vec3, Vec3> >
3973 {
3974 public:
getName(void) const3975 string getName (void) const
3976 {
3977 return "cross";
3978 }
3979
3980 protected:
doExpand(ExpandContext &,const ArgExprs & x) const3981 ExprP<Vec3> doExpand (ExpandContext&, const ArgExprs& x) const
3982 {
3983 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
3984 x.a[2] * x.b[0] - x.b[2] * x.a[0],
3985 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
3986 }
3987 };
3988
3989 class Cross16Bit : public DerivedFunc<Signature<Vec3_16Bit, Vec3_16Bit, Vec3_16Bit> >
3990 {
3991 public:
getName(void) const3992 string getName (void) const
3993 {
3994 return "cross";
3995 }
3996
3997 protected:
doExpand(ExpandContext &,const ArgExprs & x) const3998 ExprP<Vec3_16Bit> doExpand (ExpandContext&, const ArgExprs& x) const
3999 {
4000 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
4001 x.a[2] * x.b[0] - x.b[2] * x.a[0],
4002 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4003 }
4004 };
4005
4006 class Cross64Bit : public DerivedFunc<Signature<Vec3_64Bit, Vec3_64Bit, Vec3_64Bit> >
4007 {
4008 public:
getName(void) const4009 string getName (void) const
4010 {
4011 return "cross";
4012 }
4013
4014 protected:
doExpand(ExpandContext &,const ArgExprs & x) const4015 ExprP<Vec3_64Bit> doExpand (ExpandContext&, const ArgExprs& x) const
4016 {
4017 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
4018 x.a[2] * x.b[0] - x.b[2] * x.a[0],
4019 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4020 }
4021 };
4022
// Registers the three cross-product functors under the common helper name `cross`,
// one per float width.
// NOTE(review): DEFINE_CONSTRUCTOR2 is defined outside this section; presumably it
// emits a cross(a, b) expression builder for each functor -- confirm at its definition.
DEFINE_CONSTRUCTOR2(Cross, Vec3, cross, Vec3, Vec3)
DEFINE_CONSTRUCTOR2(Cross16Bit, Vec3_16Bit, cross, Vec3_16Bit, Vec3_16Bit)
DEFINE_CONSTRUCTOR2(Cross64Bit, Vec3_64Bit, cross, Vec3_64Bit, Vec3_64Bit)
4026
4027 template<int Size, class T>
4028 class Normalize : public DerivedFunc<
4029 Signature<typename ContainerOf<T, Size>::Container,
4030 typename ContainerOf<T, Size>::Container> >
4031 {
4032 public:
4033 typedef typename Normalize::Ret Ret;
4034 typedef typename Normalize::ArgExprs ArgExprs;
4035
getName(void) const4036 string getName (void) const
4037 {
4038 return "normalize";
4039 }
4040
4041 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4042 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4043 {
4044 return args.a / length<Size, T, T>(args.a);
4045 }
4046 };
4047
4048 template <int Size, class T>
4049 class FaceForward : public DerivedFunc<
4050 Signature<typename ContainerOf<T, Size>::Container,
4051 typename ContainerOf<T, Size>::Container,
4052 typename ContainerOf<T, Size>::Container,
4053 typename ContainerOf<T, Size>::Container> >
4054 {
4055 public:
4056 typedef typename FaceForward::Ret Ret;
4057 typedef typename FaceForward::ArgExprs ArgExprs;
4058
getName(void) const4059 string getName (void) const
4060 {
4061 return "faceforward";
4062 }
4063
4064 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4065 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4066 {
4067 return cond(dot(args.c, args.b) < getConstZero<T>(), args.a, -args.a);
4068 }
4069 };
4070
4071 template <int Size, class T>
4072 class Reflect : public DerivedFunc<
4073 Signature<typename ContainerOf<T, Size>::Container,
4074 typename ContainerOf<T, Size>::Container,
4075 typename ContainerOf<T, Size>::Container> >
4076 {
4077 public:
4078 typedef typename Reflect::Ret Ret;
4079 typedef typename Reflect::Arg0 Arg0;
4080 typedef typename Reflect::Arg1 Arg1;
4081 typedef typename Reflect::ArgExprs ArgExprs;
4082
getName(void) const4083 string getName (void) const
4084 {
4085 return "reflect";
4086 }
4087
4088 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4089 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4090 {
4091 const ExprP<Arg0>& i = args.a;
4092 const ExprP<Arg1>& n = args.b;
4093 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4094
4095 return i - alternatives((n * dotNI) * getConstTwo<T>(),
4096 alternatives( n * (dotNI * getConstTwo<T>()),
4097 alternatives(n * dot(i * getConstTwo<T>(), n),
4098 n * dot(i, n * getConstTwo<T>())
4099 )
4100 )
4101 );
4102 }
4103 };
4104
4105 template <int Size, class T>
4106 class Refract : public DerivedFunc<
4107 Signature<typename ContainerOf<T, Size>::Container,
4108 typename ContainerOf<T, Size>::Container,
4109 typename ContainerOf<T, Size>::Container,
4110 T> >
4111 {
4112 public:
4113 typedef typename Refract::Ret Ret;
4114 typedef typename Refract::Arg0 Arg0;
4115 typedef typename Refract::Arg1 Arg1;
4116 typedef typename Refract::Arg2 Arg2;
4117 typedef typename Refract::ArgExprs ArgExprs;
4118
getName(void) const4119 string getName (void) const
4120 {
4121 return "refract";
4122 }
4123
4124 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4125 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4126 {
4127 const ExprP<Arg0>& i = args.a;
4128 const ExprP<Arg1>& n = args.b;
4129 const ExprP<Arg2>& eta = args.c;
4130 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4131 const ExprP<T> k = bindExpression("k", ctx, getConstOne<T>() - eta * eta *
4132 (getConstOne<T>() - dotNI * dotNI));
4133 return cond(k < getConstZero<T>(),
4134 genXType<T, Size>(getConstZero<T>()),
4135 i * eta - n * (eta * dotNI + sqrt(k)));
4136 }
4137 };
4138
4139 template <class T>
4140 class PreciseFunc1 : public CFloatFunc1<T>
4141 {
4142 public:
PreciseFunc1(const string & name,DoubleFunc1 & func)4143 PreciseFunc1 (const string& name, DoubleFunc1& func) : CFloatFunc1<T> (name, func) {}
4144 protected:
precision(const EvalContext &,double,double) const4145 double precision (const EvalContext&, double, double) const { return 0.0; }
4146 };
4147
4148 template <class T>
4149 class Abs : public PreciseFunc1<T>
4150 {
4151 public:
Abs(void)4152 Abs (void) : PreciseFunc1<T> ("abs", deAbs) {}
4153 };
4154
4155 template <class T>
4156 class Sign : public PreciseFunc1<T>
4157 {
4158 public:
Sign(void)4159 Sign (void) : PreciseFunc1<T> ("sign", deSign) {}
4160 };
4161
4162 template <class T>
4163 class Floor : public PreciseFunc1<T>
4164 {
4165 public:
Floor(void)4166 Floor (void) : PreciseFunc1<T> ("floor", deFloor) {}
4167 };
4168
4169 template <class T>
4170 class Trunc : public PreciseFunc1<T>
4171 {
4172 public:
Trunc(void)4173 Trunc (void) : PreciseFunc1<T> ("trunc", deTrunc) {}
4174 };
4175
4176 template <class T>
4177 class Round : public FloatFunc1<T>
4178 {
4179 public:
getName(void) const4180 string getName (void) const { return "round"; }
4181
4182 protected:
applyPoint(const EvalContext &,double x) const4183 Interval applyPoint (const EvalContext&, double x) const
4184 {
4185 double truncated = 0.0;
4186 const double fract = deModf(x, &truncated);
4187 Interval ret;
4188
4189 if (fabs(fract) <= 0.5)
4190 ret |= truncated;
4191 if (fabs(fract) >= 0.5)
4192 ret |= truncated + deSign(fract);
4193
4194 return ret;
4195 }
4196
precision(const EvalContext &,double,double) const4197 double precision (const EvalContext&, double, double) const { return 0.0; }
4198 };
4199
4200 template <class T>
4201 class RoundEven : public PreciseFunc1<T>
4202 {
4203 public:
RoundEven(void)4204 RoundEven (void) : PreciseFunc1<T> ("roundEven", deRoundEven) {}
4205 };
4206
4207 template <class T>
4208 class Ceil : public PreciseFunc1<T>
4209 {
4210 public:
Ceil(void)4211 Ceil (void) : PreciseFunc1<T> ("ceil", deCeil) {}
4212 };
4213
4214 typedef Floor< Signature<float, float> > Floor32Bit;
4215 typedef Floor< Signature<deFloat16, deFloat16> > Floor16Bit;
4216 typedef Floor< Signature<double, double> > Floor64Bit;
4217
4218 typedef Trunc< Signature<float, float> > Trunc32Bit;
4219 typedef Trunc< Signature<deFloat16, deFloat16> > Trunc16Bit;
4220 typedef Trunc< Signature<double, double> > Trunc64Bit;
4221
4222 typedef Trunc< Signature<float, float> > Trunc32Bit;
4223 typedef Trunc< Signature<deFloat16, deFloat16> > Trunc16Bit;
4224
// fract(x) for each float width, expressed as x - floor(x).
// NOTE(review): the DEFINE_DERIVED_* macros are defined outside this section; presumably
// they declare a derived function class (first argument) named by the second argument,
// with the last argument as its expansion -- confirm at their definition.
DEFINE_DERIVED_FLOAT1(Fract, fract, x, x - app<Floor32Bit>(x));
DEFINE_DERIVED_FLOAT1_16BIT(Fract16Bit, fract, x, x - app<Floor16Bit>(x));
DEFINE_DERIVED_DOUBLE1(Fract64Bit, fract, x, x - app<Floor64Bit>(x));
4228
4229 template <class T>
4230 class PreciseFunc2 : public CFloatFunc2<T>
4231 {
4232 public:
PreciseFunc2(const string & name,DoubleFunc2 & func)4233 PreciseFunc2 (const string& name, DoubleFunc2& func) : CFloatFunc2<T> (name, func) {}
4234 protected:
precision(const EvalContext &,double,double,double) const4235 double precision (const EvalContext&, double, double, double) const { return 0.0; }
4236 };
4237
// mod(x, y) for each float width, expressed as x - y * floor(x / y).
// NOTE(review): the DEFINE_*DERIVED_* macros are defined outside this section -- confirm
// their exact expansion at the definition.
DEFINE_DERIVED_FLOAT2(Mod32Bit, mod, x, y, x - y * app<Floor32Bit>(x / y));
DEFINE_DERIVED_FLOAT2_16BIT(Mod16Bit, mod, x, y, x - y * app<Floor16Bit>(x / y));
DEFINE_DERIVED_DOUBLE2(Mod64Bit, mod, x, y, x - y * app<Floor64Bit>(x / y));

// frem(x, y) for each float width, expressed as x - y * trunc(x / y). The CASED variants
// additionally receive SPIRV_CASETYPE_FREM.
DEFINE_CASED_DERIVED_FLOAT2(FRem32Bit, frem, x, y, x - y * app<Trunc32Bit>(x / y), SPIRV_CASETYPE_FREM);
DEFINE_CASED_DERIVED_FLOAT2_16BIT(FRem16Bit, frem, x, y, x - y * app<Trunc16Bit>(x / y), SPIRV_CASETYPE_FREM);
DEFINE_CASED_DERIVED_DOUBLE2(FRem64Bit, frem, x, y, x - y * app<Trunc64Bit>(x / y), SPIRV_CASETYPE_FREM);
4245
//! Interval model of modf(): the fractional part is the return value and the whole
//! part is written back through the second argument.
template <class T>
class Modf : public PrimitiveFunc<T>
{
public:
	typedef typename Modf<T>::IArgs TIArgs;
	typedef typename Modf<T>::IRet TIRet;
	string getName (void) const
	{
		return "modf";
	}

protected:
	TIRet doApply (const EvalContext&, const TIArgs& iargs) const
	{
		Interval fracIV;
		// The whole part is stored into iargs.b, so its constness is cast away here.
		Interval& wholeIV = const_cast<Interval&>(iargs.b);
		double intPart = 0;

		// NOTE(review): TCU_INTERVAL_APPLY_MONOTONE1 is defined elsewhere; presumably it runs
		// the trailing statement(s) over the interval iargs.a and accumulates the named
		// result variable into the first argument -- confirm at the macro definition.
		TCU_INTERVAL_APPLY_MONOTONE1(fracIV, x, iargs.a, frac, frac = deModf(x, &intPart));
		TCU_INTERVAL_APPLY_MONOTONE1(wholeIV, x, iargs.a, whole,
									 deModf(x, &intPart); whole = intPart);

		if (!iargs.a.isFinite())
		{
			// Behavior on modf(Inf) not well-defined, allow anything as a fractional part
			// See Khronos bug 13907
			fracIV |= TCU_NAN;
		}

		return fracIV;
	}

	//! Reports argument index 1 (the second argument, iargs.b in doApply) as the
	//! output parameter.
	int getOutParamIndex (void) const
	{
		return 1;
	}
};
// Instantiations for 32-, 16- and 64-bit floats.
typedef Modf< Signature<float, float, float> > Modf32Bit;
typedef Modf< Signature<deFloat16, deFloat16, deFloat16> > Modf16Bit;
typedef Modf< Signature<double, double, double> > Modf64Bit;
4286
4287 template <class T>
4288 class ModfStruct : public Modf<T>
4289 {
4290 public:
getName(void) const4291 virtual string getName (void) const { return "modfstruct"; }
getSpirvCase(void) const4292 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_MODFSTRUCT; }
4293 };
4294 typedef ModfStruct< Signature<float, float, float> > ModfStruct32Bit;
4295 typedef ModfStruct< Signature<deFloat16, deFloat16, deFloat16> > ModfStruct16Bit;
4296 typedef ModfStruct< Signature<double, double, double> > ModfStruct64Bit;
4297
4298 template <class T>
Min(void)4299 class Min : public PreciseFunc2<T> { public: Min (void) : PreciseFunc2<T> ("min", deMin) {} };
4300 template <class T>
Max(void)4301 class Max : public PreciseFunc2<T> { public: Max (void) : PreciseFunc2<T> ("max", deMax) {} };
4302
4303 template <class T>
4304 class Clamp : public FloatFunc3<T>
4305 {
4306 public:
getName(void) const4307 string getName (void) const { return "clamp"; }
4308
applyExact(double x,double minVal,double maxVal) const4309 double applyExact (double x, double minVal, double maxVal) const
4310 {
4311 return de::min(de::max(x, minVal), maxVal);
4312 }
4313
precision(const EvalContext &,double,double,double minVal,double maxVal) const4314 double precision (const EvalContext&, double, double, double minVal, double maxVal) const
4315 {
4316 return minVal > maxVal ? TCU_NAN : 0.0;
4317 }
4318 };
4319
//! clamp() expression for 16-bit floats.
ExprP<deFloat16> clamp(const ExprP<deFloat16>& x, const ExprP<deFloat16>& minVal, const ExprP<deFloat16>& maxVal)
{
	typedef Clamp< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > ClampFunc;
	return app<ClampFunc>(x, minVal, maxVal);
}

//! clamp() expression for 32-bit floats.
ExprP<float> clamp(const ExprP<float>& x, const ExprP<float>& minVal, const ExprP<float>& maxVal)
{
	typedef Clamp< Signature<float, float, float, float> > ClampFunc;
	return app<ClampFunc>(x, minVal, maxVal);
}

//! clamp() expression for doubles.
ExprP<double> clamp(const ExprP<double>& x, const ExprP<double>& minVal, const ExprP<double>& maxVal)
{
	typedef Clamp< Signature<double, double, double, double> > ClampFunc;
	return app<ClampFunc>(x, minVal, maxVal);
}
4334
4335 template <class T>
4336 class NanIfGreaterOrEqual : public FloatFunc2<T>
4337 {
4338 public:
getName(void) const4339 string getName (void) const { return "nanIfGreaterOrEqual"; }
4340
applyExact(double edge0,double edge1) const4341 double applyExact (double edge0, double edge1) const
4342 {
4343 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4344 }
4345
precision(const EvalContext &,double,double edge0,double edge1) const4346 double precision (const EvalContext&, double, double edge0, double edge1) const
4347 {
4348 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4349 }
4350 };
4351
//! nanIfGreaterOrEqual() expression for 16-bit floats.
ExprP<deFloat16> nanIfGreaterOrEqual(const ExprP<deFloat16>& edge0, const ExprP<deFloat16>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<deFloat16, deFloat16, deFloat16> > GuardFunc;
	return app<GuardFunc>(edge0, edge1);
}

//! nanIfGreaterOrEqual() expression for 32-bit floats.
ExprP<float> nanIfGreaterOrEqual(const ExprP<float>& edge0, const ExprP<float>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<float, float, float> > GuardFunc;
	return app<GuardFunc>(edge0, edge1);
}

//! nanIfGreaterOrEqual() expression for doubles.
ExprP<double> nanIfGreaterOrEqual(const ExprP<double>& edge0, const ExprP<double>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<double, double, double> > GuardFunc;
	return app<GuardFunc>(edge0, edge1);
}
4366
// mix(x, y, a) for each float width. Two algebraically equal forms are accepted as
// alternatives: x * (1 - a) + y * a, and x + (y - x) * a.
// NOTE(review): the DEFINE_DERIVED_*3 macros are defined outside this section -- confirm
// their exact expansion at the definition.
DEFINE_DERIVED_FLOAT3(Mix, mix, x, y, a, alternatives((x * (constant(1.0f) - a)) + y * a,
													  x + (y - x) * a));

DEFINE_DERIVED_FLOAT3_16BIT(Mix16Bit, mix, x, y, a, alternatives((x * (constant((deFloat16)FLOAT16_1_0) - a)) + y * a,
																 x + (y - x) * a));

DEFINE_DERIVED_DOUBLE3(Mix64Bit, mix, x, y, a, alternatives((x * (constant(1.0) - a)) + y * a,
															x + (y - x) * a));
4375
//! Reference step(): 0.0 when x lies below edge, 1.0 in every other case
//! (including when the comparison is false because an operand is NaN).
static double step (double edge, double x)
{
	if (x < edge)
		return 0.0;

	return 1.0;
}
4380
4381 template <class T>
Step(void)4382 class Step : public PreciseFunc2<T> { public: Step (void) : PreciseFunc2<T> ("step", step) {} };
4383
4384 template <class T>
4385 class SmoothStep : public DerivedFunc<T>
4386 {
4387 public:
4388 typedef typename SmoothStep<T>::ArgExprs TArgExprs;
4389 typedef typename SmoothStep<T>::Ret TRet;
getName(void) const4390 string getName (void) const
4391 {
4392 return "smoothstep";
4393 }
4394
4395 protected:
4396
4397 ExprP<TRet> doExpand (ExpandContext& ctx, const TArgExprs& args) const;
4398 };
4399
4400 template<>
doExpand(ExpandContext & ctx,const SmoothStep<Signature<float,float,float,float>>::ArgExprs & args) const4401 ExprP<SmoothStep< Signature<float, float, float, float> >::Ret> SmoothStep< Signature<float, float, float, float> >::doExpand (ExpandContext& ctx, const SmoothStep< Signature<float, float, float, float> >::ArgExprs& args) const
4402 {
4403 const ExprP<float>& edge0 = args.a;
4404 const ExprP<float>& edge1 = args.b;
4405 const ExprP<float>& x = args.c;
4406 const ExprP<float> tExpr = clamp((x - edge0) / (edge1 - edge0), constant(0.0f), constant(1.0f))
4407 + nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4408 const ExprP<float> t = bindExpression("t", ctx, tExpr);
4409
4410 return (t * t * (constant(3.0f) - constant(2.0f) * t));
4411 }
4412
4413 template<>
doExpand(ExpandContext & ctx,const TArgExprs & args) const4414 ExprP<SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> >::TRet> SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> >::doExpand (ExpandContext& ctx, const TArgExprs& args) const
4415 {
4416 const ExprP<deFloat16>& edge0 = args.a;
4417 const ExprP<deFloat16>& edge1 = args.b;
4418 const ExprP<deFloat16>& x = args.c;
4419 const ExprP<deFloat16> tExpr = clamp(( x - edge0 ) / ( edge1 - edge0 ),
4420 constant((deFloat16)FLOAT16_0_0), constant((deFloat16)FLOAT16_1_0))
4421 + nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4422 const ExprP<deFloat16> t = bindExpression("t", ctx, tExpr);
4423
4424 return (t * t * (constant((deFloat16)FLOAT16_3_0) - constant((deFloat16)FLOAT16_2_0) * t));
4425 }
4426
4427 template<>
doExpand(ExpandContext & ctx,const SmoothStep<Signature<double,double,double,double>>::ArgExprs & args) const4428 ExprP<SmoothStep< Signature<double, double, double, double> >::Ret> SmoothStep< Signature<double, double, double, double> >::doExpand (ExpandContext& ctx, const SmoothStep< Signature<double, double, double, double> >::ArgExprs& args) const
4429 {
4430 const ExprP<double>& edge0 = args.a;
4431 const ExprP<double>& edge1 = args.b;
4432 const ExprP<double>& x = args.c;
4433 const ExprP<double> tExpr = clamp((x - edge0) / (edge1 - edge0), constant(0.0), constant(1.0))
4434 + nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4435 const ExprP<double> t = bindExpression("t", ctx, tExpr);
4436
4437 return (t * t * (constant(3.0) - constant(2.0) * t));
4438 }
4439
4440 //Signature<float, float, int>
//Signature<deFloat16, deFloat16, int>
4442 //Signature<double, double, int>
4443 template <class T>
4444 class FrExp : public PrimitiveFunc<T>
4445 {
4446 public:
getName(void) const4447 string getName (void) const
4448 {
4449 return "frexp";
4450 }
4451
4452 typedef typename FrExp::IRet IRet;
4453 typedef typename FrExp::IArgs IArgs;
4454 typedef typename FrExp::IArg0 IArg0;
4455 typedef typename FrExp::IArg1 IArg1;
4456
4457 protected:
doApply(const EvalContext &,const IArgs & iargs) const4458 IRet doApply (const EvalContext&, const IArgs& iargs) const
4459 {
4460 IRet ret;
4461 const IArg0& x = iargs.a;
4462 IArg1& exponent = const_cast<IArg1&>(iargs.b);
4463
4464 if (x.hasNaN() || x.contains(TCU_INFINITY) || x.contains(-TCU_INFINITY))
4465 {
4466 // GLSL (in contrast to IEEE) says that result of applying frexp
4467 // to infinity is undefined
4468 ret = Interval::unbounded() | TCU_NAN;
4469 exponent = Interval(-deLdExp(1.0, 31), deLdExp(1.0, 31)-1);
4470 }
4471 else if (!x.empty())
4472 {
4473 int loExp = 0;
4474 const double loFrac = deFrExp(x.lo(), &loExp);
4475 int hiExp = 0;
4476 const double hiFrac = deFrExp(x.hi(), &hiExp);
4477
4478 if (deSign(loFrac) != deSign(hiFrac))
4479 {
4480 exponent = Interval(-TCU_INFINITY, de::max(loExp, hiExp));
4481 ret = Interval();
4482 if (deSign(loFrac) < 0)
4483 ret |= Interval(-1.0 + DBL_EPSILON*0.5, 0.0);
4484 if (deSign(hiFrac) > 0)
4485 ret |= Interval(0.0, 1.0 - DBL_EPSILON*0.5);
4486 }
4487 else
4488 {
4489 exponent = Interval(loExp, hiExp);
4490 if (loExp == hiExp)
4491 ret = Interval(loFrac, hiFrac);
4492 else
4493 ret = deSign(loFrac) * Interval(0.5, 1.0 - DBL_EPSILON*0.5);
4494 }
4495 }
4496
4497 return ret;
4498 }
4499
getOutParamIndex(void) const4500 int getOutParamIndex (void) const
4501 {
4502 return 1;
4503 }
4504 };
4505 typedef FrExp< Signature<float, float, int> > Frexp32Bit;
4506 typedef FrExp< Signature<deFloat16, deFloat16, int> > Frexp16Bit;
4507 typedef FrExp< Signature<double, double, int> > Frexp64Bit;
4508
4509 template <class T>
4510 class FrexpStruct : public FrExp<T>
4511 {
4512 public:
getName(void) const4513 virtual string getName (void) const { return "frexpstruct"; }
getSpirvCase(void) const4514 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_FREXPSTRUCT; }
4515 };
4516 typedef FrexpStruct< Signature<float, float, int> > FrexpStruct32Bit;
4517 typedef FrexpStruct< Signature<deFloat16, deFloat16, int> > FrexpStruct16Bit;
4518 typedef FrexpStruct< Signature<double, double, int> > FrexpStruct64Bit;
4519
4520 //Signature<float, float, int>
4521 //Signature<deFloat16, deFloat16, int>
4522 //Signature<double, double, int>
4523 template <class T>
4524 class LdExp : public PrimitiveFunc<T >
4525 {
4526 public:
4527 typedef typename LdExp::IRet IRet;
4528 typedef typename LdExp::IArgs IArgs;
4529
getName(void) const4530 string getName (void) const
4531 {
4532 return "ldexp";
4533 }
4534
4535 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4536 Interval doApply (const EvalContext& ctx, const IArgs& iargs) const
4537 {
4538 const int minExp = ctx.format.getMinExp();
4539 const int maxExp = ctx.format.getMaxExp();
4540 // Restrictions from the GLSL.std.450 instruction set.
4541 // See Khronos bugzilla 11180 for rationale.
4542 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4543 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4544 if (iargs.b.lo() < minExp) ret |= 0.0;
4545 if (!ret.isFinite()) ret |= TCU_NAN;
4546 return ctx.format.convert(ret);
4547 }
4548 };
4549
4550 template <>
doApply(const EvalContext & ctx,const IArgs & iargs) const4551 Interval LdExp <Signature<double, double, int>>::doApply(const EvalContext& ctx, const IArgs& iargs) const
4552 {
4553 const int minExp = ctx.format.getMinExp();
4554 const int maxExp = ctx.format.getMaxExp();
4555 // Restrictions from the GLSL.std.450 instruction set.
4556 // See Khronos bugzilla 11180 for rationale.
4557 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4558 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4559 // Add 1ULP precision tolerance to account for differing rounding modes between the GPU and deLdExp.
4560 ret += Interval(-ctx.format.ulp(ret.lo()), ctx.format.ulp(ret.hi()));
4561 if (iargs.b.lo() < minExp) ret |= 0.0;
4562 if (!ret.isFinite()) ret |= TCU_NAN;
4563 return ctx.format.convert(ret);
4564 }
4565
4566 template<int Rows, int Columns, class T>
4567 class Transpose : public PrimitiveFunc<Signature<Matrix<T, Rows, Columns>,
4568 Matrix<T, Columns, Rows> > >
4569 {
4570 public:
4571 typedef typename Transpose::IRet IRet;
4572 typedef typename Transpose::IArgs IArgs;
4573
getName(void) const4574 string getName (void) const
4575 {
4576 return "transpose";
4577 }
4578
4579 protected:
doApply(const EvalContext &,const IArgs & iargs) const4580 IRet doApply (const EvalContext&, const IArgs& iargs) const
4581 {
4582 IRet ret;
4583
4584 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
4585 {
4586 for (int colNdx = 0; colNdx < Columns; ++colNdx)
4587 ret(rowNdx, colNdx) = iargs.a(colNdx, rowNdx);
4588 }
4589
4590 return ret;
4591 }
4592 };
4593
4594 template<typename Ret, typename Arg0, typename Arg1>
4595 class MulFunc : public PrimitiveFunc<Signature<Ret, Arg0, Arg1> >
4596 {
4597 public:
getName(void) const4598 string getName (void) const { return "mul"; }
4599
4600 protected:
doPrint(ostream & os,const BaseArgExprs & args) const4601 void doPrint (ostream& os, const BaseArgExprs& args) const
4602 {
4603 os << "(" << *args[0] << " * " << *args[1] << ")";
4604 }
4605 };
4606
4607 template<typename T, int LeftRows, int Middle, int RightCols>
4608 class MatMul : public MulFunc<Matrix<T, LeftRows, RightCols>,
4609 Matrix<T, LeftRows, Middle>,
4610 Matrix<T, Middle, RightCols> >
4611 {
4612 protected:
4613 typedef typename MatMul::IRet IRet;
4614 typedef typename MatMul::IArgs IArgs;
4615 typedef typename MatMul::IArg0 IArg0;
4616 typedef typename MatMul::IArg1 IArg1;
4617
doApply(const EvalContext & ctx,const IArgs & iargs) const4618 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4619 {
4620 const IArg0& left = iargs.a;
4621 const IArg1& right = iargs.b;
4622 IRet ret;
4623
4624 for (int row = 0; row < LeftRows; ++row)
4625 {
4626 for (int col = 0; col < RightCols; ++col)
4627 {
4628 Interval element (0.0);
4629
4630 for (int ndx = 0; ndx < Middle; ++ndx)
4631 element = call<Add< Signature<T, T, T> > >(ctx, element,
4632 call<Mul< Signature<T, T, T> > >(ctx, left[ndx][row], right[col][ndx]));
4633
4634 ret[col][row] = element;
4635 }
4636 }
4637
4638 return ret;
4639 }
4640 };
4641
4642 template<typename T, int Rows, int Cols>
4643 class VecMatMul : public MulFunc<Vector<T, Cols>,
4644 Vector<T, Rows>,
4645 Matrix<T, Rows, Cols> >
4646 {
4647 public:
4648 typedef typename VecMatMul::IRet IRet;
4649 typedef typename VecMatMul::IArgs IArgs;
4650 typedef typename VecMatMul::IArg0 IArg0;
4651 typedef typename VecMatMul::IArg1 IArg1;
4652
4653 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4654 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4655 {
4656 const IArg0& left = iargs.a;
4657 const IArg1& right = iargs.b;
4658 IRet ret;
4659
4660 for (int col = 0; col < Cols; ++col)
4661 {
4662 Interval element (0.0);
4663
4664 for (int row = 0; row < Rows; ++row)
4665 element = call<Add< Signature<T, T, T> > >(ctx, element, call<Mul< Signature<T, T, T> > >(ctx, left[row], right[col][row]));
4666
4667 ret[col] = element;
4668 }
4669
4670 return ret;
4671 }
4672 };
4673
4674 template<int Rows, int Cols, class T>
4675 class MatVecMul : public MulFunc<Vector<T, Rows>,
4676 Matrix<T, Rows, Cols>,
4677 Vector<T, Cols> >
4678 {
4679 public:
4680 typedef typename MatVecMul::IRet IRet;
4681 typedef typename MatVecMul::IArgs IArgs;
4682 typedef typename MatVecMul::IArg0 IArg0;
4683 typedef typename MatVecMul::IArg1 IArg1;
4684
4685 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4686 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4687 {
4688 const IArg0& left = iargs.a;
4689 const IArg1& right = iargs.b;
4690
4691 return call<VecMatMul<T, Cols, Rows> >(ctx, right,
4692 call<Transpose<Rows, Cols, T> >(ctx, left));
4693 }
4694 };
4695
4696 template<int Rows, int Cols, class T>
4697 class OuterProduct : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>,
4698 Vector<T, Rows>,
4699 Vector<T, Cols> > >
4700 {
4701 public:
4702 typedef typename OuterProduct::IRet IRet;
4703 typedef typename OuterProduct::IArgs IArgs;
4704
getName(void) const4705 string getName (void) const
4706 {
4707 return "outerProduct";
4708 }
4709
4710 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4711 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4712 {
4713 IRet ret;
4714
4715 for (int row = 0; row < Rows; ++row)
4716 {
4717 for (int col = 0; col < Cols; ++col)
4718 ret[col][row] = call<Mul< Signature<T, T, T> > >(ctx, iargs.a[row], iargs.b[col]);
4719 }
4720
4721 return ret;
4722 }
4723 };
4724
4725 template<int Rows, int Cols, class T>
outerProduct(const ExprP<Vector<T,Rows>> & left,const ExprP<Vector<T,Cols>> & right)4726 ExprP<Matrix<T, Rows, Cols> > outerProduct (const ExprP<Vector<T, Rows> >& left,
4727 const ExprP<Vector<T, Cols> >& right)
4728 {
4729 return app<OuterProduct<Rows, Cols, T> >(left, right);
4730 }
4731
4732 template<class T>
4733 class DeterminantBase : public DerivedFunc<T>
4734 {
4735 public:
getName(void) const4736 string getName (void) const { return "determinant"; }
4737 };
4738
4739 template<int Size> class Determinant;
4740 template<int Size> class Determinant16bit;
4741 template<int Size> class Determinant64bit;
4742
4743 template<int Size>
determinant(ExprP<Matrix<float,Size,Size>> mat)4744 ExprP<float> determinant (ExprP<Matrix<float, Size, Size> > mat)
4745 {
4746 return app<Determinant<Size> >(mat);
4747 }
4748
4749 template<int Size>
determinant(ExprP<Matrix<deFloat16,Size,Size>> mat)4750 ExprP<deFloat16> determinant (ExprP<Matrix<deFloat16, Size, Size> > mat)
4751 {
4752 return app<Determinant16bit<Size> >(mat);
4753 }
4754
4755 template<int Size>
determinant(ExprP<Matrix<double,Size,Size>> mat)4756 ExprP<double> determinant (ExprP<Matrix<double, Size, Size> > mat)
4757 {
4758 return app<Determinant64bit<Size> >(mat);
4759 }
4760
4761 template<>
4762 class Determinant<2> : public DeterminantBase<Signature<float, Matrix<float, 2, 2> > >
4763 {
4764 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4765 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4766 {
4767 ExprP<Mat2> mat = args.a;
4768
4769 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4770 }
4771 };
4772
4773 template<>
4774 class Determinant<3> : public DeterminantBase<Signature<float, Matrix<float, 3, 3> > >
4775 {
4776 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4777 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4778 {
4779 ExprP<Mat3> mat = args.a;
4780
4781 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
4782 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
4783 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
4784 }
4785 };
4786
4787 template<>
4788 class Determinant<4> : public DeterminantBase<Signature<float, Matrix<float, 4, 4> > >
4789 {
4790 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4791 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4792 {
4793 ExprP<Mat4> mat = args.a;
4794 ExprP<Mat3> minors[4];
4795
4796 for (int ndx = 0; ndx < 4; ++ndx)
4797 {
4798 ExprP<Vec4> minorColumns[3];
4799 ExprP<Vec3> columns[3];
4800
4801 for (int col = 0; col < 3; ++col)
4802 minorColumns[col] = mat[col < ndx ? col : col + 1];
4803
4804 for (int col = 0; col < 3; ++col)
4805 columns[col] = vec3(minorColumns[0][col+1],
4806 minorColumns[1][col+1],
4807 minorColumns[2][col+1]);
4808
4809 minors[ndx] = bindExpression("minor", ctx,
4810 mat3(columns[0], columns[1], columns[2]));
4811 }
4812
4813 return (mat[0][0] * determinant(minors[0]) -
4814 mat[1][0] * determinant(minors[1]) +
4815 mat[2][0] * determinant(minors[2]) -
4816 mat[3][0] * determinant(minors[3]));
4817 }
4818 };
4819
4820 template<>
4821 class Determinant16bit<2> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 2, 2> > >
4822 {
4823 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4824 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4825 {
4826 ExprP<Mat2_16b> mat = args.a;
4827
4828 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4829 }
4830 };
4831
4832 template<>
4833 class Determinant16bit<3> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 3, 3> > >
4834 {
4835 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4836 ExprP<Ret> doExpand(ExpandContext&, const ArgExprs& args) const
4837 {
4838 ExprP<Mat3_16b> mat = args.a;
4839
4840 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
4841 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
4842 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
4843 }
4844 };
4845
4846 template<>
4847 class Determinant16bit<4> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 4, 4> > >
4848 {
4849 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4850 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
4851 {
4852 ExprP<Mat4_16b> mat = args.a;
4853 ExprP<Mat3_16b> minors[4];
4854
4855 for (int ndx = 0; ndx < 4; ++ndx)
4856 {
4857 ExprP<Vec4_16Bit> minorColumns[3];
4858 ExprP<Vec3_16Bit> columns[3];
4859
4860 for (int col = 0; col < 3; ++col)
4861 minorColumns[col] = mat[col < ndx ? col : col + 1];
4862
4863 for (int col = 0; col < 3; ++col)
4864 columns[col] = vec3(minorColumns[0][col + 1],
4865 minorColumns[1][col + 1],
4866 minorColumns[2][col + 1]);
4867
4868 minors[ndx] = bindExpression("minor", ctx,
4869 mat3(columns[0], columns[1], columns[2]));
4870 }
4871
4872 return (mat[0][0] * determinant(minors[0]) -
4873 mat[1][0] * determinant(minors[1]) +
4874 mat[2][0] * determinant(minors[2]) -
4875 mat[3][0] * determinant(minors[3]));
4876 }
4877 };
4878
4879 template<>
4880 class Determinant64bit<2> : public DeterminantBase<Signature<double, Matrix<double, 2, 2> > >
4881 {
4882 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4883 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4884 {
4885 ExprP<Matrix2d> mat = args.a;
4886
4887 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4888 }
4889 };
4890
4891 template<>
4892 class Determinant64bit<3> : public DeterminantBase<Signature<double, Matrix<double, 3, 3> > >
4893 {
4894 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4895 ExprP<Ret> doExpand(ExpandContext&, const ArgExprs& args) const
4896 {
4897 ExprP<Matrix3d> mat = args.a;
4898
4899 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
4900 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
4901 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
4902 }
4903 };
4904
4905 template<>
4906 class Determinant64bit<4> : public DeterminantBase<Signature<double, Matrix<double, 4, 4> > >
4907 {
4908 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4909 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
4910 {
4911 ExprP<Matrix4d> mat = args.a;
4912 ExprP<Matrix3d> minors[4];
4913
4914 for (int ndx = 0; ndx < 4; ++ndx)
4915 {
4916 ExprP<Vec4_64Bit> minorColumns[3];
4917 ExprP<Vec3_64Bit> columns[3];
4918
4919 for (int col = 0; col < 3; ++col)
4920 minorColumns[col] = mat[col < ndx ? col : col + 1];
4921
4922 for (int col = 0; col < 3; ++col)
4923 columns[col] = vec3(minorColumns[0][col + 1],
4924 minorColumns[1][col + 1],
4925 minorColumns[2][col + 1]);
4926
4927 minors[ndx] = bindExpression("minor", ctx,
4928 mat3(columns[0], columns[1], columns[2]));
4929 }
4930
4931 return (mat[0][0] * determinant(minors[0]) -
4932 mat[1][0] * determinant(minors[1]) +
4933 mat[2][0] * determinant(minors[2]) -
4934 mat[3][0] * determinant(minors[3]));
4935 }
4936 };
4937
4938 template<int Size> class Inverse;
4939
4940 template <int Size>
inverse(ExprP<Matrix<float,Size,Size>> mat)4941 ExprP<Matrix<float, Size, Size> > inverse (ExprP<Matrix<float, Size, Size> > mat)
4942 {
4943 return app<Inverse<Size> >(mat);
4944 }
4945
4946 template<int Size> class Inverse16bit;
4947
4948 template <int Size>
inverse(ExprP<Matrix<deFloat16,Size,Size>> mat)4949 ExprP<Matrix<deFloat16, Size, Size> > inverse (ExprP<Matrix<deFloat16, Size, Size> > mat)
4950 {
4951 return app<Inverse16bit<Size> >(mat);
4952 }
4953
4954 template<int Size> class Inverse64bit;
4955
4956 template <int Size>
inverse(ExprP<Matrix<double,Size,Size>> mat)4957 ExprP<Matrix<double, Size, Size> > inverse (ExprP<Matrix<double, Size, Size> > mat)
4958 {
4959 return app<Inverse64bit<Size> >(mat);
4960 }
4961
4962 template<>
4963 class Inverse<2> : public DerivedFunc<Signature<Mat2, Mat2> >
4964 {
4965 public:
getName(void) const4966 string getName (void) const
4967 {
4968 return "inverse";
4969 }
4970
4971 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4972 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4973 {
4974 ExprP<Mat2> mat = args.a;
4975 ExprP<float> det = bindExpression("det", ctx, determinant(mat));
4976
4977 return mat2(vec2(mat[1][1] / det, -mat[0][1] / det),
4978 vec2(-mat[1][0] / det, mat[0][0] / det));
4979 }
4980 };
4981
4982 template<>
4983 class Inverse<3> : public DerivedFunc<Signature<Mat3, Mat3> >
4984 {
4985 public:
getName(void) const4986 string getName (void) const
4987 {
4988 return "inverse";
4989 }
4990
4991 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4992 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
4993 {
4994 ExprP<Mat3> mat = args.a;
4995 ExprP<Mat2> invA = bindExpression("invA", ctx,
4996 inverse(mat2(vec2(mat[0][0], mat[0][1]),
4997 vec2(mat[1][0], mat[1][1]))));
4998
4999 ExprP<Vec2> matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5000 ExprP<Vec2> matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5001 ExprP<float> matD = bindExpression("matD", ctx, mat[2][2]);
5002
5003 ExprP<float> schur = bindExpression("schur", ctx,
5004 constant(1.0f) /
5005 (matD - dot(matC * invA, matB)));
5006
5007 ExprP<Vec2> t1 = invA * matB;
5008 ExprP<Vec2> t2 = t1 * schur;
5009 ExprP<Mat2> t3 = outerProduct(t2, matC);
5010 ExprP<Mat2> t4 = t3 * invA;
5011 ExprP<Mat2> t5 = invA + t4;
5012 ExprP<Mat2> blockA = bindExpression("blockA", ctx, t5);
5013 ExprP<Vec2> blockB = bindExpression("blockB", ctx,
5014 (invA * matB) * -schur);
5015 ExprP<Vec2> blockC = bindExpression("blockC", ctx,
5016 (matC * invA) * -schur);
5017
5018 return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
5019 vec3(blockA[1][0], blockA[1][1], blockC[1]),
5020 vec3(blockB[0], blockB[1], schur));
5021 }
5022 };
5023
5024 template<>
5025 class Inverse<4> : public DerivedFunc<Signature<Mat4, Mat4> >
5026 {
5027 public:
getName(void) const5028 string getName (void) const { return "inverse"; }
5029
5030 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5031 ExprP<Ret> doExpand (ExpandContext& ctx,
5032 const ArgExprs& args) const
5033 {
5034 ExprP<Mat4> mat = args.a;
5035 ExprP<Mat2> invA = bindExpression("invA", ctx,
5036 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5037 vec2(mat[1][0], mat[1][1]))));
5038 ExprP<Mat2> matB = bindExpression("matB", ctx,
5039 mat2(vec2(mat[2][0], mat[2][1]),
5040 vec2(mat[3][0], mat[3][1])));
5041 ExprP<Mat2> matC = bindExpression("matC", ctx,
5042 mat2(vec2(mat[0][2], mat[0][3]),
5043 vec2(mat[1][2], mat[1][3])));
5044 ExprP<Mat2> matD = bindExpression("matD", ctx,
5045 mat2(vec2(mat[2][2], mat[2][3]),
5046 vec2(mat[3][2], mat[3][3])));
5047 ExprP<Mat2> schur = bindExpression("schur", ctx,
5048 inverse(matD + -(matC * invA * matB)));
5049 ExprP<Mat2> blockA = bindExpression("blockA", ctx,
5050 invA + (invA * matB * schur * matC * invA));
5051 ExprP<Mat2> blockB = bindExpression("blockB", ctx,
5052 (-invA) * matB * schur);
5053 ExprP<Mat2> blockC = bindExpression("blockC", ctx,
5054 (-schur) * matC * invA);
5055
5056 return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5057 vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5058 vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5059 vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5060 }
5061 };
5062
5063 template<>
5064 class Inverse16bit<2> : public DerivedFunc<Signature<Mat2_16b, Mat2_16b> >
5065 {
5066 public:
getName(void) const5067 string getName (void) const
5068 {
5069 return "inverse";
5070 }
5071
5072 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5073 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5074 {
5075 ExprP<Mat2_16b> mat = args.a;
5076 ExprP<deFloat16> det = bindExpression("det", ctx, determinant(mat));
5077
5078 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)),
5079 vec2((-mat[1][0] / det), (mat[0][0] / det)));
5080 }
5081 };
5082
5083 template<>
5084 class Inverse16bit<3> : public DerivedFunc<Signature<Mat3_16b, Mat3_16b> >
5085 {
5086 public:
getName(void) const5087 string getName(void) const
5088 {
5089 return "inverse";
5090 }
5091
5092 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5093 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
5094 {
5095 ExprP<Mat3_16b> mat = args.a;
5096 ExprP<Mat2_16b> invA = bindExpression("invA", ctx,
5097 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5098 vec2(mat[1][0], mat[1][1]))));
5099
5100 ExprP<Vec2_16Bit> matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5101 ExprP<Vec2_16Bit> matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5102 ExprP<Mat3_16b::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);
5103
5104 ExprP<Mat3_16b::Scalar> schur = bindExpression("schur", ctx,
5105 constant((deFloat16)FLOAT16_1_0) /
5106 (matD - dot(matC * invA, matB)));
5107
5108 ExprP<Vec2_16Bit> t1 = invA * matB;
5109 ExprP<Vec2_16Bit> t2 = t1 * schur;
5110 ExprP<Mat2_16b> t3 = outerProduct(t2, matC);
5111 ExprP<Mat2_16b> t4 = t3 * invA;
5112 ExprP<Mat2_16b> t5 = invA + t4;
5113 ExprP<Mat2_16b> blockA = bindExpression("blockA", ctx, t5);
5114 ExprP<Vec2_16Bit> blockB = bindExpression("blockB", ctx,
5115 (invA * matB) * -schur);
5116 ExprP<Vec2_16Bit> blockC = bindExpression("blockC", ctx,
5117 (matC * invA) * -schur);
5118
5119 return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
5120 vec3(blockA[1][0], blockA[1][1], blockC[1]),
5121 vec3(blockB[0], blockB[1], schur));
5122 }
5123 };
5124
5125 template<>
5126 class Inverse16bit<4> : public DerivedFunc<Signature<Mat4_16b, Mat4_16b> >
5127 {
5128 public:
getName(void) const5129 string getName(void) const { return "inverse"; }
5130
5131 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5132 ExprP<Ret> doExpand(ExpandContext& ctx,
5133 const ArgExprs& args) const
5134 {
5135 ExprP<Mat4_16b> mat = args.a;
5136 ExprP<Mat2_16b> invA = bindExpression("invA", ctx,
5137 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5138 vec2(mat[1][0], mat[1][1]))));
5139 ExprP<Mat2_16b> matB = bindExpression("matB", ctx,
5140 mat2(vec2(mat[2][0], mat[2][1]),
5141 vec2(mat[3][0], mat[3][1])));
5142 ExprP<Mat2_16b> matC = bindExpression("matC", ctx,
5143 mat2(vec2(mat[0][2], mat[0][3]),
5144 vec2(mat[1][2], mat[1][3])));
5145 ExprP<Mat2_16b> matD = bindExpression("matD", ctx,
5146 mat2(vec2(mat[2][2], mat[2][3]),
5147 vec2(mat[3][2], mat[3][3])));
5148 ExprP<Mat2_16b> schur = bindExpression("schur", ctx,
5149 inverse(matD + -(matC * invA * matB)));
5150 ExprP<Mat2_16b> blockA = bindExpression("blockA", ctx,
5151 invA + (invA * matB * schur * matC * invA));
5152 ExprP<Mat2_16b> blockB = bindExpression("blockB", ctx,
5153 (-invA) * matB * schur);
5154 ExprP<Mat2_16b> blockC = bindExpression("blockC", ctx,
5155 (-schur) * matC * invA);
5156
5157 return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5158 vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5159 vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5160 vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5161 }
5162 };
5163
5164 template<>
5165 class Inverse64bit<2> : public DerivedFunc<Signature<Matrix2d, Matrix2d> >
5166 {
5167 public:
getName(void) const5168 string getName (void) const
5169 {
5170 return "inverse";
5171 }
5172
5173 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5174 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5175 {
5176 ExprP<Matrix2d> mat = args.a;
5177 ExprP<double> det = bindExpression("det", ctx, determinant(mat));
5178
5179 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)),
5180 vec2((-mat[1][0] / det), (mat[0][0] / det)));
5181 }
5182 };
5183
5184 template<>
5185 class Inverse64bit<3> : public DerivedFunc<Signature<Matrix3d, Matrix3d> >
5186 {
5187 public:
getName(void) const5188 string getName(void) const
5189 {
5190 return "inverse";
5191 }
5192
5193 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5194 ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
5195 {
5196 ExprP<Matrix3d> mat = args.a;
5197 ExprP<Matrix2d> invA = bindExpression("invA", ctx,
5198 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5199 vec2(mat[1][0], mat[1][1]))));
5200
5201 ExprP<Vec2_64Bit> matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5202 ExprP<Vec2_64Bit> matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5203 ExprP<Matrix3d::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);
5204
5205 ExprP<Matrix3d::Scalar> schur = bindExpression("schur", ctx,
5206 constant(1.0) /
5207 (matD - dot(matC * invA, matB)));
5208
5209 ExprP<Vec2_64Bit> t1 = invA * matB;
5210 ExprP<Vec2_64Bit> t2 = t1 * schur;
5211 ExprP<Matrix2d> t3 = outerProduct(t2, matC);
5212 ExprP<Matrix2d> t4 = t3 * invA;
5213 ExprP<Matrix2d> t5 = invA + t4;
5214 ExprP<Matrix2d> blockA = bindExpression("blockA", ctx, t5);
5215 ExprP<Vec2_64Bit> blockB = bindExpression("blockB", ctx,
5216 (invA * matB) * -schur);
5217 ExprP<Vec2_64Bit> blockC = bindExpression("blockC", ctx,
5218 (matC * invA) * -schur);
5219
5220 return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
5221 vec3(blockA[1][0], blockA[1][1], blockC[1]),
5222 vec3(blockB[0], blockB[1], schur));
5223 }
5224 };
5225
5226 template<>
5227 class Inverse64bit<4> : public DerivedFunc<Signature<Matrix4d, Matrix4d> >
5228 {
5229 public:
getName(void) const5230 string getName(void) const { return "inverse"; }
5231
5232 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5233 ExprP<Ret> doExpand(ExpandContext& ctx,
5234 const ArgExprs& args) const
5235 {
5236 ExprP<Matrix4d> mat = args.a;
5237 ExprP<Matrix2d> invA = bindExpression("invA", ctx,
5238 inverse(mat2(vec2(mat[0][0], mat[0][1]),
5239 vec2(mat[1][0], mat[1][1]))));
5240 ExprP<Matrix2d> matB = bindExpression("matB", ctx,
5241 mat2(vec2(mat[2][0], mat[2][1]),
5242 vec2(mat[3][0], mat[3][1])));
5243 ExprP<Matrix2d> matC = bindExpression("matC", ctx,
5244 mat2(vec2(mat[0][2], mat[0][3]),
5245 vec2(mat[1][2], mat[1][3])));
5246 ExprP<Matrix2d> matD = bindExpression("matD", ctx,
5247 mat2(vec2(mat[2][2], mat[2][3]),
5248 vec2(mat[3][2], mat[3][3])));
5249 ExprP<Matrix2d> schur = bindExpression("schur", ctx,
5250 inverse(matD + -(matC * invA * matB)));
5251 ExprP<Matrix2d> blockA = bindExpression("blockA", ctx,
5252 invA + (invA * matB * schur * matC * invA));
5253 ExprP<Matrix2d> blockB = bindExpression("blockB", ctx,
5254 (-invA) * matB * schur);
5255 ExprP<Matrix2d> blockC = bindExpression("blockC", ctx,
5256 (-schur) * matC * invA);
5257
5258 return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5259 vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5260 vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5261 vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5262 }
5263 };
5264
5265 //Signature<float, float, float, float>
5266 //Signature<deFloat16, deFloat16, deFloat16, deFloat16>
5267 //Signature<double, double, double, double>
5268 template <class T>
5269 class Fma : public DerivedFunc<T>
5270 {
5271 public:
5272 typedef typename Fma::ArgExprs ArgExprs;
5273 typedef typename Fma::Ret Ret;
5274
getName(void) const5275 string getName (void) const
5276 {
5277 return "fma";
5278 }
5279
5280 protected:
doExpand(ExpandContext &,const ArgExprs & x) const5281 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& x) const
5282 {
5283 return x.a * x.b + x.c;
5284 }
5285 };
5286
5287 } // Functions
5288
5289 using namespace Functions;
5290
5291 template <typename T>
operator [](int i) const5292 ExprP<typename T::Element> ContainerExprPBase<T>::operator[] (int i) const
5293 {
5294 return Functions::getComponent(exprP<T>(*this), i);
5295 }
5296
// Scalar addition of float expressions, applied through Add.
ExprP<float> operator+ (const ExprP<float>& arg0, const ExprP<float>& arg1)
{
	typedef Add< Signature<float, float, float> > AddFunc;

	return app<AddFunc>(arg0, arg1);
}

// Scalar addition of 16-bit float expressions, applied through Add.
ExprP<deFloat16> operator+ (const ExprP<deFloat16>& arg0, const ExprP<deFloat16>& arg1)
{
	typedef Add< Signature<deFloat16, deFloat16, deFloat16> > AddFunc;

	return app<AddFunc>(arg0, arg1);
}

// Scalar addition of double expressions, applied through Add.
ExprP<double> operator+ (const ExprP<double>& arg0, const ExprP<double>& arg1)
{
	typedef Add< Signature<double, double, double> > AddFunc;

	return app<AddFunc>(arg0, arg1);
}
5311
5312 template <typename T>
operator -(const ExprP<T> & arg0,const ExprP<T> & arg1)5313 ExprP<T> operator- (const ExprP<T>& arg0, const ExprP<T>& arg1)
5314 {
5315 return app<Sub <Signature <T,T,T> > >(arg0, arg1);
5316 }
5317
5318 template <typename T>
operator -(const ExprP<T> & arg0)5319 ExprP<T> operator- (const ExprP<T>& arg0)
5320 {
5321 return app<Negate< Signature<T, T> > >(arg0);
5322 }
5323
// Scalar multiplication of float expressions, applied through Mul.
ExprP<float> operator* (const ExprP<float>& arg0, const ExprP<float>& arg1)
{
	typedef Mul< Signature<float, float, float> > MulFunc_;

	return app<MulFunc_>(arg0, arg1);
}

// Scalar multiplication of 16-bit float expressions, applied through Mul.
ExprP<deFloat16> operator* (const ExprP<deFloat16>& arg0, const ExprP<deFloat16>& arg1)
{
	typedef Mul< Signature<deFloat16, deFloat16, deFloat16> > MulFunc_;

	return app<MulFunc_>(arg0, arg1);
}

// Scalar multiplication of double expressions, applied through Mul.
ExprP<double> operator* (const ExprP<double>& arg0, const ExprP<double>& arg1)
{
	typedef Mul< Signature<double, double, double> > MulFunc_;

	return app<MulFunc_>(arg0, arg1);
}
5338
5339 template <typename T>
operator /(const ExprP<T> & arg0,const ExprP<T> & arg1)5340 ExprP<T> operator/ (const ExprP<T>& arg0, const ExprP<T>& arg1)
5341 {
5342 return app<Div< Signature<T, T, T> > >(arg0, arg1);
5343 }
5344
5345
5346 template <typename Sig_, int Size>
5347 class GenFunc : public PrimitiveFunc<Signature<
5348 typename ContainerOf<typename Sig_::Ret, Size>::Container,
5349 typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5350 typename ContainerOf<typename Sig_::Arg1, Size>::Container,
5351 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5352 typename ContainerOf<typename Sig_::Arg3, Size>::Container> >
5353 {
5354 public:
5355 typedef typename GenFunc::IArgs IArgs;
5356 typedef typename GenFunc::IRet IRet;
5357
GenFunc(const Func<Sig_> & scalarFunc)5358 GenFunc (const Func<Sig_>& scalarFunc) : m_func (scalarFunc) {}
5359
getSpirvCase(void) const5360 SpirVCaseT getSpirvCase (void) const
5361 {
5362 return m_func.getSpirvCase();
5363 }
5364
getName(void) const5365 string getName (void) const
5366 {
5367 return m_func.getName();
5368 }
5369
getOutParamIndex(void) const5370 int getOutParamIndex (void) const
5371 {
5372 return m_func.getOutParamIndex();
5373 }
5374
getRequiredExtension(void) const5375 string getRequiredExtension (void) const
5376 {
5377 return m_func.getRequiredExtension();
5378 }
5379
getInputRange(const bool is16bit) const5380 Interval getInputRange (const bool is16bit) const
5381 {
5382 return m_func.getInputRange(is16bit);
5383 }
5384
5385 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5386 void doPrint (ostream& os, const BaseArgExprs& args) const
5387 {
5388 m_func.print(os, args);
5389 }
5390
doApply(const EvalContext & ctx,const IArgs & iargs) const5391 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
5392 {
5393 IRet ret;
5394
5395 for (int ndx = 0; ndx < Size; ++ndx)
5396 {
5397 ret[ndx] =
5398 m_func.apply(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5399 }
5400
5401 return ret;
5402 }
5403
doFail(const EvalContext & ctx,const IArgs & iargs) const5404 IRet doFail (const EvalContext& ctx, const IArgs& iargs) const
5405 {
5406 IRet ret;
5407
5408 for (int ndx = 0; ndx < Size; ++ndx)
5409 {
5410 ret[ndx] =
5411 m_func.fail(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5412 }
5413
5414 return ret;
5415 }
5416
doGetUsedFuncs(FuncSet & dst) const5417 void doGetUsedFuncs (FuncSet& dst) const
5418 {
5419 m_func.getUsedFuncs(dst);
5420 }
5421
5422 const Func<Sig_>& m_func;
5423 };
5424
5425 template <typename F, int Size>
5426 class VectorizedFunc : public GenFunc<typename F::Sig, Size>
5427 {
5428 public:
VectorizedFunc(void)5429 VectorizedFunc (void) : GenFunc<typename F::Sig, Size>(instance<F>()) {}
5430 };
5431
5432 template <typename Sig_, int Size>
5433 class FixedGenFunc : public PrimitiveFunc <Signature<
5434 typename ContainerOf<typename Sig_::Ret, Size>::Container,
5435 typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5436 typename Sig_::Arg1,
5437 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5438 typename ContainerOf<typename Sig_::Arg3, Size>::Container> >
5439 {
5440 public:
5441 typedef typename FixedGenFunc::IArgs IArgs;
5442 typedef typename FixedGenFunc::IRet IRet;
5443
getName(void) const5444 string getName (void) const
5445 {
5446 return this->doGetScalarFunc().getName();
5447 }
5448
getSpirvCase(void) const5449 SpirVCaseT getSpirvCase (void) const
5450 {
5451 return this->doGetScalarFunc().getSpirvCase();
5452 }
5453
5454 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5455 void doPrint (ostream& os, const BaseArgExprs& args) const
5456 {
5457 this->doGetScalarFunc().print(os, args);
5458 }
5459
doApply(const EvalContext & ctx,const IArgs & iargs) const5460 IRet doApply (const EvalContext& ctx,
5461 const IArgs& iargs) const
5462 {
5463 IRet ret;
5464 const Func<Sig_>& func = this->doGetScalarFunc();
5465
5466 for (int ndx = 0; ndx < Size; ++ndx)
5467 ret[ndx] = func.apply(ctx, iargs.a[ndx], iargs.b, iargs.c[ndx], iargs.d[ndx]);
5468
5469 return ret;
5470 }
5471
5472 virtual const Func<Sig_>& doGetScalarFunc (void) const = 0;
5473 };
5474
5475 template <typename F, int Size>
5476 class FixedVecFunc : public FixedGenFunc<typename F::Sig, Size>
5477 {
5478 protected:
doGetScalarFunc(void) const5479 const Func<typename F::Sig>& doGetScalarFunc (void) const { return instance<F>(); }
5480 };
5481
5482 template<typename Sig>
5483 struct GenFuncs
5484 {
GenFuncsvkt::shaderexecutor::GenFuncs5485 GenFuncs (const Func<Sig>& func_,
5486 const GenFunc<Sig, 2>& func2_,
5487 const GenFunc<Sig, 3>& func3_,
5488 const GenFunc<Sig, 4>& func4_)
5489 : func (func_)
5490 , func2 (func2_)
5491 , func3 (func3_)
5492 , func4 (func4_)
5493 {}
5494
5495 const Func<Sig>& func;
5496 const GenFunc<Sig, 2>& func2;
5497 const GenFunc<Sig, 3>& func3;
5498 const GenFunc<Sig, 4>& func4;
5499 };
5500
5501 template<typename F>
makeVectorizedFuncs(void)5502 GenFuncs<typename F::Sig> makeVectorizedFuncs (void)
5503 {
5504 return GenFuncs<typename F::Sig>(instance<F>(),
5505 instance<VectorizedFunc<F, 2> >(),
5506 instance<VectorizedFunc<F, 3> >(),
5507 instance<VectorizedFunc<F, 4> >());
5508 }
5509
5510 template<typename T, int Size>
operator /(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5511 ExprP<Vector<T, Size> > operator/(const ExprP<Vector<T, Size> >& arg0,
5512 const ExprP<T>& arg1)
5513 {
5514 return app<FixedVecFunc<Div< Signature<T, T, T> >, Size> >(arg0, arg1);
5515 }
5516
5517 template<typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0)5518 ExprP<Vector<T, Size> > operator-(const ExprP<Vector<T, Size> >& arg0)
5519 {
5520 return app<VectorizedFunc<Negate< Signature<T, T> >, Size> >(arg0);
5521 }
5522
5523 template<typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5524 ExprP<Vector<T, Size> > operator-(const ExprP<Vector<T, Size> >& arg0,
5525 const ExprP<Vector<T, Size> >& arg1)
5526 {
5527 return app<VectorizedFunc<Sub<Signature<T, T, T> >, Size> >(arg0, arg1);
5528 }
5529
5530 template<int Size, typename T>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5531 ExprP<Vector<T, Size> > operator*(const ExprP<Vector<T, Size> >& arg0,
5532 const ExprP<T>& arg1)
5533 {
5534 return app<FixedVecFunc<Mul< Signature<T, T, T> >, Size> >(arg0, arg1);
5535 }
5536
5537 template<typename T, int Size>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5538 ExprP<Vector<T, Size> > operator*(const ExprP<Vector<T, Size> >& arg0,
5539 const ExprP<Vector<T, Size> >& arg1)
5540 {
5541 return app<VectorizedFunc<Mul< Signature<T, T, T> >, Size> >(arg0, arg1);
5542 }
5543
5544 template<int LeftRows, int Middle, int RightCols, typename T>
5545 ExprP<Matrix<T, LeftRows, RightCols> >
operator *(const ExprP<Matrix<T,LeftRows,Middle>> & left,const ExprP<Matrix<T,Middle,RightCols>> & right)5546 operator* (const ExprP<Matrix<T, LeftRows, Middle> >& left,
5547 const ExprP<Matrix<T, Middle, RightCols> >& right)
5548 {
5549 return app<MatMul<T, LeftRows, Middle, RightCols> >(left, right);
5550 }
5551
5552 template<int Rows, int Cols, typename T>
operator *(const ExprP<Vector<T,Cols>> & left,const ExprP<Matrix<T,Rows,Cols>> & right)5553 ExprP<Vector<T, Rows> > operator* (const ExprP<Vector<T, Cols> >& left,
5554 const ExprP<Matrix<T, Rows, Cols> >& right)
5555 {
5556 return app<VecMatMul<T, Rows, Cols> >(left, right);
5557 }
5558
5559 template<int Rows, int Cols, class T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<Vector<T,Rows>> & right)5560 ExprP<Vector<T, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
5561 const ExprP<Vector<T, Rows> >& right)
5562 {
5563 return app<MatVecMul<Rows, Cols, T> >(left, right);
5564 }
5565
5566 template<int Rows, int Cols, typename T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<T> & right)5567 ExprP<Matrix<T, Rows, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
5568 const ExprP<T>& right)
5569 {
5570 return app<ScalarMatFunc<Mul< Signature<T, T, T> >, Rows, Cols> >(left, right);
5571 }
5572
5573 template<int Rows, int Cols>
operator +(const ExprP<Matrix<float,Rows,Cols>> & left,const ExprP<Matrix<float,Rows,Cols>> & right)5574 ExprP<Matrix<float, Rows, Cols> > operator+ (const ExprP<Matrix<float, Rows, Cols> >& left,
5575 const ExprP<Matrix<float, Rows, Cols> >& right)
5576 {
5577 return app<CompMatFunc<Add< Signature<float, float, float> >,float, Rows, Cols> >(left, right);
5578 }
5579
5580 template<int Rows, int Cols>
operator +(const ExprP<Matrix<deFloat16,Rows,Cols>> & left,const ExprP<Matrix<deFloat16,Rows,Cols>> & right)5581 ExprP<Matrix<deFloat16, Rows, Cols> > operator+ (const ExprP<Matrix<deFloat16, Rows, Cols> >& left,
5582 const ExprP<Matrix<deFloat16, Rows, Cols> >& right)
5583 {
5584 return app<CompMatFunc<Add< Signature<deFloat16, deFloat16, deFloat16> >, deFloat16, Rows, Cols> >(left, right);
5585 }
5586
5587 template<int Rows, int Cols>
operator +(const ExprP<Matrix<double,Rows,Cols>> & left,const ExprP<Matrix<double,Rows,Cols>> & right)5588 ExprP<Matrix<double, Rows, Cols> > operator+ (const ExprP<Matrix<double, Rows, Cols> >& left,
5589 const ExprP<Matrix<double, Rows, Cols> >& right)
5590 {
5591 return app<CompMatFunc<Add< Signature<double, double, double> >, double, Rows, Cols> >(left, right);
5592 }
5593
5594 template<typename T, int Rows, int Cols>
operator -(const ExprP<Matrix<T,Rows,Cols>> & mat)5595 ExprP<Matrix<T, Rows, Cols> > operator- (const ExprP<Matrix<T, Rows, Cols> >& mat)
5596 {
5597 return app<MatNeg<T, Rows, Cols> >(mat);
5598 }
5599
5600 template <typename T>
5601 class Sampling
5602 {
5603 public:
genFixeds(const FloatFormat &,const Precision,vector<T> &,const Interval &) const5604 virtual void genFixeds (const FloatFormat&, const Precision, vector<T>&, const Interval&) const {}
genRandom(const FloatFormat &,const Precision,Random &,const Interval &) const5605 virtual T genRandom (const FloatFormat&,const Precision, Random&, const Interval&) const { return T(); }
removeNotInRange(vector<T> &,const Interval &,const Precision) const5606 virtual void removeNotInRange (vector<T>&, const Interval&, const Precision) const {};
5607 };
5608
5609 template <>
5610 class DefaultSampling<Void> : public Sampling<Void>
5611 {
5612 public:
genFixeds(const FloatFormat &,const Precision,vector<Void> & dst,const Interval &) const5613 void genFixeds (const FloatFormat&, const Precision, vector<Void>& dst, const Interval&) const { dst.push_back(Void()); }
5614 };
5615
5616 template <>
5617 class DefaultSampling<bool> : public Sampling<bool>
5618 {
5619 public:
genFixeds(const FloatFormat &,const Precision,vector<bool> & dst,const Interval &) const5620 void genFixeds (const FloatFormat&, const Precision, vector<bool>& dst, const Interval&) const
5621 {
5622 dst.push_back(true);
5623 dst.push_back(false);
5624 }
5625 };
5626
5627 template <>
5628 class DefaultSampling<int> : public Sampling<int>
5629 {
5630 public:
genRandom(const FloatFormat &,const Precision prec,Random & rnd,const Interval &) const5631 int genRandom (const FloatFormat&, const Precision prec, Random& rnd, const Interval&) const
5632 {
5633 const int exp = rnd.getInt(0, getNumBits(prec)-2);
5634 const int sign = rnd.getBool() ? -1 : 1;
5635
5636 return sign * rnd.getInt(0, (deInt32)1 << exp);
5637 }
5638
genFixeds(const FloatFormat &,const Precision,vector<int> & dst,const Interval &) const5639 void genFixeds (const FloatFormat&, const Precision, vector<int>& dst, const Interval&) const
5640 {
5641 dst.push_back(0);
5642 dst.push_back(-1);
5643 dst.push_back(1);
5644 }
5645
5646 private:
getNumBits(Precision prec)5647 static inline int getNumBits (Precision prec)
5648 {
5649 switch (prec)
5650 {
5651 case glu::PRECISION_LAST:
5652 case glu::PRECISION_MEDIUMP: return 16;
5653 case glu::PRECISION_HIGHP: return 32;
5654 default:
5655 DE_ASSERT(false);
5656 return 0;
5657 }
5658 }
5659 };
5660
5661 template <>
5662 class DefaultSampling<float> : public Sampling<float>
5663 {
5664 public:
5665 float genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5666 void genFixeds (const FloatFormat& format, const Precision prec, vector<float>& dst, const Interval& inputRange) const;
5667 void removeNotInRange (vector<float>& dst, const Interval& inputRange, const Precision prec) const;
5668 };
5669
5670 template <>
5671 class DefaultSampling<double> : public Sampling<double>
5672 {
5673 public:
5674 double genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5675 void genFixeds (const FloatFormat& format, const Precision prec, vector<double>& dst, const Interval& inputRange) const;
5676 void removeNotInRange (vector<double>& dst, const Interval& inputRange, const Precision prec) const;
5677 };
5678
isDenorm16(deFloat16 v)5679 static bool isDenorm16(deFloat16 v)
5680 {
5681 const deUint16 mantissa = 0x03FF;
5682 const deUint16 exponent = 0x7C00;
5683 return ((exponent & v) == 0 && (mantissa & v) != 0);
5684 }
5685
5686 //! Generate a random double from a reasonable general-purpose distribution.
randomDouble(const FloatFormat & format,Random & rnd,const Interval & inputRange)5687 double randomDouble(const FloatFormat& format, Random& rnd, const Interval& inputRange)
5688 {
5689 // No testing of subnormals. TODO: Could integrate float controls for some operations.
5690 const int minExp = format.getMinExp();
5691 const int maxExp = format.getMaxExp();
5692 const bool haveSubnormal = false;
5693 const double midpoint = inputRange.midpoint();
5694
5695 // Choose exponent so that the cumulative distribution is cubic.
5696 // This makes the probability distribution quadratic, with the peak centered on zero.
5697 const double minRoot = deCbrt(minExp - 0.5 - (haveSubnormal ? 1.0 : 0.0));
5698 const double maxRoot = deCbrt(maxExp + 0.5);
5699 const int fractionBits = format.getFractionBits();
5700 const int exp = int(deRoundEven(dePow(rnd.getDouble(minRoot, maxRoot), 3.0)));
5701
5702 // Generate some occasional special numbers
5703 switch (rnd.getInt(0, 64))
5704 {
5705 case 0: return inputRange.contains(0) ? 0 : midpoint;
5706 case 1: return inputRange.contains(TCU_INFINITY) ? TCU_INFINITY : midpoint;
5707 case 2: return inputRange.contains(-TCU_INFINITY) ? -TCU_INFINITY : midpoint;
5708 case 3: return inputRange.contains(TCU_NAN) ? TCU_NAN : midpoint;
5709 default: break;
5710 }
5711
5712 DE_ASSERT(fractionBits < std::numeric_limits<double>::digits);
5713
5714 // Normal number
5715 double base = deLdExp(1.0, exp);
5716 double quantum = deLdExp(1.0, exp - fractionBits); // smallest representable difference in the binade
5717 double significand = 0.0;
5718 switch (rnd.getInt(0, 16))
5719 {
5720 case 0: // The highest number in this binade, significand is all bits one.
5721 significand = base - quantum;
5722 break;
5723 case 1: // Significand is one.
5724 significand = quantum;
5725 break;
5726 case 2: // Significand is zero.
5727 significand = 0.0;
5728 break;
5729 default: // Random (evenly distributed) significand.
5730 {
5731 deUint64 intFraction = rnd.getUint64() & ((1 << fractionBits) - 1);
5732 significand = double(intFraction) * quantum;
5733 }
5734 }
5735
5736 // Produce positive numbers more often than negative.
5737 double value = (rnd.getInt(0, 3) == 0 ? -1.0 : 1.0) * (base + significand);
5738 return inputRange.contains(value) ? value : midpoint;
5739 }
5740
5741 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const5742 float DefaultSampling<float>::genRandom (const FloatFormat& format,
5743 Precision prec,
5744 Random& rnd,
5745 const Interval& inputRange) const
5746 {
5747 DE_UNREF(prec);
5748 return (float)randomDouble(format, rnd, inputRange);
5749 }
5750
5751 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<float> & dst,const Interval & inputRange) const5752 void DefaultSampling<float>::genFixeds (const FloatFormat& format, const Precision prec, vector<float>& dst, const Interval& inputRange) const
5753 {
5754 const int minExp = format.getMinExp();
5755 const int maxExp = format.getMaxExp();
5756 const int fractionBits = format.getFractionBits();
5757 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
5758 const float minNormalized = deFloatLdExp(1.0f, minExp);
5759 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
5760
5761 // NaN
5762 dst.push_back(TCU_NAN);
5763 // Zero
5764 dst.push_back(0.0f);
5765
5766 for (int sign = -1; sign <= 1; sign += 2)
5767 {
5768 // Smallest normalized
5769 dst.push_back((float)sign * minNormalized);
5770
5771 // Next smallest normalized
5772 dst.push_back((float)sign * (minNormalized + minQuantum));
5773
5774 dst.push_back((float)sign * 0.5f);
5775 dst.push_back((float)sign * 1.0f);
5776 dst.push_back((float)sign * 2.0f);
5777
5778 // Largest number
5779 dst.push_back((float)sign * (deFloatLdExp(1.0f, maxExp) +
5780 (deFloatLdExp(1.0f, maxExp) - maxQuantum)));
5781
5782 dst.push_back((float)sign * TCU_INFINITY);
5783 }
5784 removeNotInRange(dst, inputRange, prec);
5785 }
5786
removeNotInRange(vector<float> & dst,const Interval & inputRange,const Precision prec) const5787 void DefaultSampling<float>::removeNotInRange (vector<float>& dst, const Interval& inputRange, const Precision prec) const
5788 {
5789 for (vector<float>::iterator it = dst.begin(); it < dst.end();)
5790 {
5791 // Remove out of range values. PRECISION_LAST means this is an FP16 test so remove any values that
5792 // will be denorms when converted to FP16. (This is used in the precision_fp16_storage32b test group).
5793 if ( !inputRange.contains(static_cast<double>(*it)) || (prec == glu::PRECISION_LAST && isDenorm16(deFloat32To16Round(*it, DE_ROUNDINGMODE_TO_ZERO))))
5794 it = dst.erase(it);
5795 else
5796 ++it;
5797 }
5798 }
5799
5800 //! Generate a random double from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const5801 double DefaultSampling<double>::genRandom (const FloatFormat& format,
5802 Precision prec,
5803 Random& rnd,
5804 const Interval& inputRange) const
5805 {
5806 DE_UNREF(prec);
5807 return randomDouble(format, rnd, inputRange);
5808 }
5809
5810 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<double> & dst,const Interval & inputRange) const5811 void DefaultSampling<double>::genFixeds (const FloatFormat& format, const Precision prec, vector<double>& dst, const Interval& inputRange) const
5812 {
5813 const int minExp = format.getMinExp();
5814 const int maxExp = format.getMaxExp();
5815 const int fractionBits = format.getFractionBits();
5816 const double minQuantum = deLdExp(1.0, minExp - fractionBits);
5817 const double minNormalized = deLdExp(1.0, minExp);
5818 const double maxQuantum = deLdExp(1.0, maxExp - fractionBits);
5819
5820 // NaN
5821 dst.push_back(TCU_NAN);
5822 // Zero
5823 dst.push_back(0.0);
5824
5825 for (int sign = -1; sign <= 1; sign += 2)
5826 {
5827 // Smallest normalized
5828 dst.push_back((double)sign * minNormalized);
5829
5830 // Next smallest normalized
5831 dst.push_back((double)sign * (minNormalized + minQuantum));
5832
5833 dst.push_back((double)sign * 0.5);
5834 dst.push_back((double)sign * 1.0);
5835 dst.push_back((double)sign * 2.0);
5836
5837 // Largest number
5838 dst.push_back((double)sign * (deLdExp(1.0, maxExp) + (deLdExp(1.0, maxExp) - maxQuantum)));
5839
5840 dst.push_back((double)sign * TCU_INFINITY);
5841 }
5842 removeNotInRange(dst, inputRange, prec);
5843 }
5844
removeNotInRange(vector<double> & dst,const Interval & inputRange,const Precision) const5845 void DefaultSampling<double>::removeNotInRange (vector<double>& dst, const Interval& inputRange, const Precision) const
5846 {
5847 for (vector<double>::iterator it = dst.begin(); it < dst.end();)
5848 {
5849 if ( !inputRange.contains(*it) )
5850 it = dst.erase(it);
5851 else
5852 ++it;
5853 }
5854 }
5855
5856 template <>
5857 class DefaultSampling<deFloat16> : public Sampling<deFloat16>
5858 {
5859 public:
5860 deFloat16 genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5861 void genFixeds (const FloatFormat& format, const Precision prec, vector<deFloat16>& dst, const Interval& inputRange) const;
5862 private:
5863 void removeNotInRange(vector<deFloat16>& dst, const Interval& inputRange, const Precision prec) const;
5864 };
5865
5866 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,const Precision prec,Random & rnd,const Interval & inputRange) const5867 deFloat16 DefaultSampling<deFloat16>::genRandom (const FloatFormat& format, const Precision prec,
5868 Random& rnd, const Interval& inputRange) const
5869 {
5870 DE_UNREF(prec);
5871 return deFloat64To16Round(randomDouble(format, rnd, inputRange), DE_ROUNDINGMODE_TO_NEAREST_EVEN);
5872 }
5873
5874 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<deFloat16> & dst,const Interval & inputRange) const5875 void DefaultSampling<deFloat16>::genFixeds (const FloatFormat& format, const Precision prec, vector<deFloat16>& dst, const Interval& inputRange) const
5876 {
5877 dst.push_back(deUint16(0x3E00)); //1.5
5878 dst.push_back(deUint16(0x3D00)); //1.25
5879 dst.push_back(deUint16(0x3F00)); //1.75
5880 // Zero
5881 dst.push_back(deUint16(0x0000));
5882 dst.push_back(deUint16(0x8000));
5883 // Infinity
5884 dst.push_back(deUint16(0x7c00));
5885 dst.push_back(deUint16(0xfc00));
5886 // SNaN
5887 dst.push_back(deUint16(0x7c0f));
5888 dst.push_back(deUint16(0xfc0f));
5889 // QNaN
5890 dst.push_back(deUint16(0x7cf0));
5891 dst.push_back(deUint16(0xfcf0));
5892 // Normalized
5893 dst.push_back(deUint16(0x0401));
5894 dst.push_back(deUint16(0x8401));
5895 // Some normal number
5896 dst.push_back(deUint16(0x14cb));
5897 dst.push_back(deUint16(0x94cb));
5898
5899 const int minExp = format.getMinExp();
5900 const int maxExp = format.getMaxExp();
5901 const int fractionBits = format.getFractionBits();
5902 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
5903 const float minNormalized = deFloatLdExp(1.0f, minExp);
5904 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
5905
5906 for (float sign = -1.0; sign <= 1.0f; sign += 2.0f)
5907 {
5908 // Smallest normalized
5909 dst.push_back(deFloat32To16Round(sign * minNormalized, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5910
5911 // Next smallest normalized
5912 dst.push_back(deFloat32To16Round(sign * (minNormalized + minQuantum), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5913
5914 dst.push_back(deFloat32To16Round(sign * 0.5f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5915 dst.push_back(deFloat32To16Round(sign * 1.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5916 dst.push_back(deFloat32To16Round(sign * 2.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5917
5918 // Largest number
5919 dst.push_back(deFloat32To16Round(sign * (deFloatLdExp(1.0f, maxExp) +
5920 (deFloatLdExp(1.0f, maxExp) - maxQuantum)), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5921
5922 dst.push_back(deFloat32To16Round(sign * TCU_INFINITY, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5923 }
5924 removeNotInRange(dst, inputRange, prec);
5925 }
5926
removeNotInRange(vector<deFloat16> & dst,const Interval & inputRange,const Precision) const5927 void DefaultSampling<deFloat16>::removeNotInRange(vector<deFloat16>& dst, const Interval& inputRange, const Precision) const
5928 {
5929 for (vector<deFloat16>::iterator it = dst.begin(); it < dst.end();)
5930 {
5931 if (inputRange.contains(static_cast<double>(*it)))
5932 ++it;
5933 else
5934 it = dst.erase(it);
5935 }
5936 }
5937
5938 template <typename T, int Size>
5939 class DefaultSampling<Vector<T, Size> > : public Sampling<Vector<T, Size> >
5940 {
5941 public:
5942 typedef Vector<T, Size> Value;
5943
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const5944 Value genRandom (const FloatFormat& fmt, const Precision prec, Random& rnd, const Interval& inputRange) const
5945 {
5946 Value ret;
5947
5948 for (int ndx = 0; ndx < Size; ++ndx)
5949 ret[ndx] = instance<DefaultSampling<T> >().genRandom(fmt, prec, rnd, inputRange);
5950
5951 return ret;
5952 }
5953
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const5954 void genFixeds (const FloatFormat& fmt, const Precision prec, vector<Value>& dst, const Interval& inputRange) const
5955 {
5956 vector<T> scalars;
5957
5958 instance<DefaultSampling<T> >().genFixeds(fmt, prec, scalars, inputRange);
5959
5960 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
5961 dst.push_back(Value(scalars[scalarNdx]));
5962 }
5963 };
5964
5965 template <typename T, int Rows, int Columns>
5966 class DefaultSampling<Matrix<T, Rows, Columns> > : public Sampling<Matrix<T, Rows, Columns> >
5967 {
5968 public:
5969 typedef Matrix<T, Rows, Columns> Value;
5970
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const5971 Value genRandom (const FloatFormat& fmt, const Precision prec, Random& rnd, const Interval& inputRange) const
5972 {
5973 Value ret;
5974
5975 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
5976 for (int colNdx = 0; colNdx < Columns; ++colNdx)
5977 ret(rowNdx, colNdx) = instance<DefaultSampling<T> >().genRandom(fmt, prec, rnd, inputRange);
5978
5979 return ret;
5980 }
5981
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const5982 void genFixeds (const FloatFormat& fmt, const Precision prec, vector<Value>& dst, const Interval& inputRange) const
5983 {
5984 vector<T> scalars;
5985
5986 instance<DefaultSampling<T> >().genFixeds(fmt, prec, scalars, inputRange);
5987
5988 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
5989 dst.push_back(Value(scalars[scalarNdx]));
5990
5991 if (Columns == Rows)
5992 {
5993 Value mat (T(0.0));
5994 T x = T(1.0f);
5995 mat[0][0] = x;
5996 for (int ndx = 0; ndx < Columns; ++ndx)
5997 {
5998 mat[Columns-1-ndx][ndx] = x;
5999 x = static_cast<T>(x * static_cast<T>(2.0f));
6000 }
6001 dst.push_back(mat);
6002 }
6003 }
6004 };
6005
6006 struct CaseContext
6007 {
CaseContextvkt::shaderexecutor::CaseContext6008 CaseContext (const string& name_,
6009 TestContext& testContext_,
6010 const FloatFormat& floatFormat_,
6011 const FloatFormat& highpFormat_,
6012 const Precision precision_,
6013 const ShaderType shaderType_,
6014 const size_t numRandoms_,
6015 const PrecisionTestFeatures precisionTestFeatures_ = PRECISION_TEST_FEATURES_NONE,
6016 const bool isPackFloat16b_ = false,
6017 const bool isFloat64b_ = false)
6018 : name (name_)
6019 , testContext (testContext_)
6020 , floatFormat (floatFormat_)
6021 , highpFormat (highpFormat_)
6022 , precision (precision_)
6023 , shaderType (shaderType_)
6024 , numRandoms (numRandoms_)
6025 , inputRange (-TCU_INFINITY, TCU_INFINITY)
6026 , precisionTestFeatures (precisionTestFeatures_)
6027 , isPackFloat16b (isPackFloat16b_)
6028 , isFloat64b (isFloat64b_)
6029 {}
6030
6031 string name;
6032 TestContext& testContext;
6033 FloatFormat floatFormat;
6034 FloatFormat highpFormat;
6035 Precision precision;
6036 ShaderType shaderType;
6037 size_t numRandoms;
6038 Interval inputRange;
6039 PrecisionTestFeatures precisionTestFeatures;
6040 bool isPackFloat16b;
6041 bool isFloat64b;
6042 };
6043
6044 template<typename In0_ = Void, typename In1_ = Void, typename In2_ = Void, typename In3_ = Void>
6045 struct InTypes
6046 {
6047 typedef In0_ In0;
6048 typedef In1_ In1;
6049 typedef In2_ In2;
6050 typedef In3_ In3;
6051 };
6052
6053 template <typename In>
numInputs(void)6054 int numInputs (void)
6055 {
6056 return (!isTypeValid<typename In::In0>() ? 0 :
6057 !isTypeValid<typename In::In1>() ? 1 :
6058 !isTypeValid<typename In::In2>() ? 2 :
6059 !isTypeValid<typename In::In3>() ? 3 :
6060 4);
6061 }
6062
6063 template<typename Out0_, typename Out1_ = Void>
6064 struct OutTypes
6065 {
6066 typedef Out0_ Out0;
6067 typedef Out1_ Out1;
6068 };
6069
6070 template <typename Out>
numOutputs(void)6071 int numOutputs (void)
6072 {
6073 return (!isTypeValid<typename Out::Out0>() ? 0 :
6074 !isTypeValid<typename Out::Out1>() ? 1 :
6075 2);
6076 }
6077
6078 template<typename In>
6079 struct Inputs
6080 {
6081 vector<typename In::In0> in0;
6082 vector<typename In::In1> in1;
6083 vector<typename In::In2> in2;
6084 vector<typename In::In3> in3;
6085 };
6086
6087 template<typename Out>
6088 struct Outputs
6089 {
Outputsvkt::shaderexecutor::Outputs6090 Outputs (size_t size) : out0(size), out1(size) {}
6091
6092 vector<typename Out::Out0> out0;
6093 vector<typename Out::Out1> out1;
6094 };
6095
6096 template<typename In, typename Out>
6097 struct Variables
6098 {
6099 VariableP<typename In::In0> in0;
6100 VariableP<typename In::In1> in1;
6101 VariableP<typename In::In2> in2;
6102 VariableP<typename In::In3> in3;
6103 VariableP<typename Out::Out0> out0;
6104 VariableP<typename Out::Out1> out1;
6105 };
6106
6107 template<typename In>
6108 struct Samplings
6109 {
Samplingsvkt::shaderexecutor::Samplings6110 Samplings (const Sampling<typename In::In0>& in0_,
6111 const Sampling<typename In::In1>& in1_,
6112 const Sampling<typename In::In2>& in2_,
6113 const Sampling<typename In::In3>& in3_)
6114 : in0 (in0_), in1 (in1_), in2 (in2_), in3 (in3_) {}
6115
6116 const Sampling<typename In::In0>& in0;
6117 const Sampling<typename In::In1>& in1;
6118 const Sampling<typename In::In2>& in2;
6119 const Sampling<typename In::In3>& in3;
6120 };
6121
6122 template<typename In>
6123 struct DefaultSamplings : Samplings<In>
6124 {
DefaultSamplingsvkt::shaderexecutor::DefaultSamplings6125 DefaultSamplings (void)
6126 : Samplings<In>(instance<DefaultSampling<typename In::In0> >(),
6127 instance<DefaultSampling<typename In::In1> >(),
6128 instance<DefaultSampling<typename In::In2> >(),
6129 instance<DefaultSampling<typename In::In3> >()) {}
6130 };
6131
6132 template <typename In, typename Out>
6133 class BuiltinPrecisionCaseTestInstance : public TestInstance
6134 {
6135 public:
BuiltinPrecisionCaseTestInstance(Context & context,const CaseContext caseCtx,const ShaderSpec & shaderSpec,const Variables<In,Out> variables,const Samplings<In> & samplings,const StatementP stmt,bool modularOp=false)6136 BuiltinPrecisionCaseTestInstance (Context& context,
6137 const CaseContext caseCtx,
6138 const ShaderSpec& shaderSpec,
6139 const Variables<In, Out> variables,
6140 const Samplings<In>& samplings,
6141 const StatementP stmt,
6142 bool modularOp = false)
6143 : TestInstance (context)
6144 , m_caseCtx (caseCtx)
6145 , m_variables (variables)
6146 , m_samplings (samplings)
6147 , m_stmt (stmt)
6148 , m_executor (createExecutor(context, caseCtx.shaderType, shaderSpec))
6149 , m_modularOp (modularOp)
6150 {
6151 }
6152 virtual tcu::TestStatus iterate (void);
6153
6154 protected:
6155 CaseContext m_caseCtx;
6156 Variables<In, Out> m_variables;
6157 const Samplings<In>& m_samplings;
6158 StatementP m_stmt;
6159 de::UniquePtr<ShaderExecutor> m_executor;
6160 bool m_modularOp;
6161 };
6162
// Runs the whole test in a single call:
//   1. generates the input tuples,
//   2. executes the shader once for all tuples via m_executor,
//   3. evaluates the reference statement for every tuple and checks that each
//      shader output lies inside the reference interval,
//   4. logs failing samples (at most maxMsgs of them, or every sample when
//      GLS_LOG_ALL_RESULTS is set) and a summary.
//
// Returns tcu::TestStatus::fail() when at least one input tuple failed,
// tcu::TestStatus::pass() otherwise.
template<class In, class Out>
tcu::TestStatus BuiltinPrecisionCaseTestInstance<In, Out>::iterate (void)
{
	typedef typename In::In0 In0;
	typedef typename In::In1 In1;
	typedef typename In::In2 In2;
	typedef typename In::In3 In3;
	typedef typename Out::Out0 Out0;
	typedef typename Out::Out1 Out1;

	// The return value (if any) is not used here.
	// NOTE(review): presumably this throws when a required feature is missing - confirm.
	areFeaturesSupported(m_context, m_caseCtx.precisionTestFeatures);
	// The random seed is a fixed constant offset by the command line base seed.
	Inputs<In> inputs = generateInputs(m_samplings, m_caseCtx.floatFormat, m_caseCtx.precision, m_caseCtx.numRandoms, 0xdeadbeefu + m_caseCtx.testContext.getCommandLine().getBaseSeed(), m_caseCtx.inputRange);
	const FloatFormat& fmt = m_caseCtx.floatFormat;
	const int inCount = numInputs<In>();
	const int outCount = numOutputs<Out>();
	// With no inputs at all the statement is still evaluated once.
	const size_t numValues = (inCount > 0) ? inputs.in0.size() : 1;
	Outputs<Out> outputs (numValues);
	const FloatFormat highpFmt = m_caseCtx.highpFormat;
	const int maxMsgs = 100;	// Upper bound for the number of logged failing samples.
	int numErrors = 0;
	Environment env; // Hoisted out of the inner loop for optimization.
	// Only used through status.check() below; the final verdict is derived from numErrors.
	ResultCollector status;
	TestLog& testLog = m_context.getTestContext().getLog();

	// Modulo operations need exactly two inputs and have exactly one output.
	if (m_modularOp)
	{
		DE_ASSERT(inCount == 2);
		DE_ASSERT(outCount == 1);
	}

	// Raw pointer tables handed to the executor.
	const void* inputArr[] =
	{
		inputs.in0.data(), inputs.in1.data(), inputs.in2.data(), inputs.in3.data(),
	};
	void* outputArr[] =
	{
		outputs.out0.data(), outputs.out1.data(),
	};

	// Print out the statement and its definitions
	testLog << TestLog::Message << "Statement: " << m_stmt << TestLog::EndMessage;
	{
		ostringstream oss;
		FuncSet funcs;

		m_stmt->getUsedFuncs(funcs);
		for (FuncSet::const_iterator it = funcs.begin(); it != funcs.end(); ++it)
		{
			(*it)->printDefinition(oss);
		}
		if (!funcs.empty())
			testLog << TestLog::Message << "Reference definitions:\n" << oss.str()
					<< TestLog::EndMessage;
	}

	// Sanity check: every used input array must have numValues elements.
	switch (inCount)
	{
		case 4:
			DE_ASSERT(inputs.in3.size() == numValues);
			// Fallthrough
		case 3:
			DE_ASSERT(inputs.in2.size() == numValues);
			// Fallthrough
		case 2:
			DE_ASSERT(inputs.in1.size() == numValues);
			// Fallthrough
		case 1:
			DE_ASSERT(inputs.in0.size() == numValues);
			// Fallthrough
		default:
			break;
	}

	// Run the shader for all input tuples in one go.
	m_executor->execute(int(numValues), inputArr, outputArr);

	// Initialize environment with dummy values so we don't need to bind in inner loop.
	{
		const typename Traits<In0>::IVal in0;
		const typename Traits<In1>::IVal in1;
		const typename Traits<In2>::IVal in2;
		const typename Traits<In3>::IVal in3;
		const typename Traits<Out0>::IVal reference0;
		const typename Traits<Out1>::IVal reference1;

		env.bind(*m_variables.in0, in0);
		env.bind(*m_variables.in1, in1);
		env.bind(*m_variables.in2, in2);
		env.bind(*m_variables.in3, in3);
		env.bind(*m_variables.out0, reference0);
		env.bind(*m_variables.out1, reference1);
	}

	// For each input tuple, compute output reference interval and compare
	// shader output to the reference.
	for (size_t valueNdx = 0; valueNdx < numValues; valueNdx++)
	{
		bool result = true;
		// These select the value formatter used when a sample is logged below.
		const bool isInput16Bit = m_executor->areInputs16Bit();
		const bool isInput64Bit = m_executor->areInputs64Bit();

		DE_ASSERT(!(isInput16Bit && isInput64Bit));

		typename Traits<Out0>::IVal reference0;
		typename Traits<Out1>::IVal reference1;

		if (valueNdx % (size_t)TOUCH_WATCHDOG_VALUE_FREQUENCY == 0)
			m_context.getTestContext().touchWatchdog();

		// Overwrite the bound input values with the current tuple, rounded to fmt.
		env.lookup(*m_variables.in0) = convert<In0>(fmt, round(fmt, inputs.in0[valueNdx]));
		env.lookup(*m_variables.in1) = convert<In1>(fmt, round(fmt, inputs.in1[valueNdx]));
		env.lookup(*m_variables.in2) = convert<In2>(fmt, round(fmt, inputs.in2[valueNdx]));
		env.lookup(*m_variables.in3) = convert<In3>(fmt, round(fmt, inputs.in3[valueNdx]));

		{
			EvalContext ctx (fmt, m_caseCtx.precision, env, 0);
			// Evaluate the reference statement; results are read back from env below.
			m_stmt->execute(ctx);

			switch (outCount)
			{
				case 2:
					reference1 = convert<Out1>(highpFmt, env.lookup(*m_variables.out1));
					if (!status.check(contains(reference1, outputs.out1[valueNdx], m_caseCtx.isPackFloat16b), "Shader output 1 is outside acceptable range"))
						result = false;
					// Fallthrough
				case 1:
				{
					// Pass b from mod(a, b) if we are in the modulo operation.
					const tcu::Maybe<In1> modularDivisor = (m_modularOp ? tcu::just(inputs.in1[valueNdx]) : tcu::nothing<In1>());

					reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
					if (!status.check(contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor), "Shader output 0 is outside acceptable range"))
					{
						// First check failed: notify the statement, re-read the reference
						// and check once more. The sample only counts as failed when this
						// second check fails as well.
						m_stmt->failed(ctx);
						reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
						if (!status.check(contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor), "Shader output 0 is outside acceptable range"))
							result = false;
					}
				}
				// Fallthrough
				default: break;
			}

		}
		if (!result)
			++numErrors;

		// Log the sample: failures up to maxMsgs, or everything with GLS_LOG_ALL_RESULTS.
		if ((!result && numErrors <= maxMsgs) || GLS_LOG_ALL_RESULTS)
		{
			MessageBuilder builder = testLog.message();

			builder << (result ? "Passed" : "Failed") << " sample:\n";

			if (inCount > 0)
			{
				builder << "\t" << m_variables.in0->getName() << " = "
						<< (isInput64Bit ? value64ToString(highpFmt, inputs.in0[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in0[valueNdx]) : value32ToString(highpFmt, inputs.in0[valueNdx]))) << "\n";
			}

			if (inCount > 1)
			{
				builder << "\t" << m_variables.in1->getName() << " = "
						<< (isInput64Bit ? value64ToString(highpFmt, inputs.in1[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in1[valueNdx]) : value32ToString(highpFmt, inputs.in1[valueNdx]))) << "\n";
			}

			if (inCount > 2)
			{
				builder << "\t" << m_variables.in2->getName() << " = "
						<< (isInput64Bit ? value64ToString(highpFmt, inputs.in2[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in2[valueNdx]) : value32ToString(highpFmt, inputs.in2[valueNdx]))) << "\n";
			}

			if (inCount > 3)
			{
				builder << "\t" << m_variables.in3->getName() << " = "
						<< (isInput64Bit ? value64ToString(highpFmt, inputs.in3[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in3[valueNdx]) : value32ToString(highpFmt, inputs.in3[valueNdx]))) << "\n";
			}

			if (outCount > 0)
			{
				// Comparison cases use a dedicated message format for output 0.
				if (m_executor->spirvCase() == SPIRV_CASETYPE_COMPARE)
				{
					builder << "Output:\n"
							<< comparisonMessage(outputs.out0[valueNdx])
							<< "Expected result:\n"
							<< comparisonMessageInterval<typename Out::Out0>(reference0) << "\n";
				}
				else
				{
					builder << "\t" << m_variables.out0->getName() << " = "
							<< (m_executor->isOutput64Bit(0u) ? value64ToString(highpFmt, outputs.out0[valueNdx]) : (m_executor->isOutput16Bit(0u) || m_caseCtx.isPackFloat16b ? value16ToString(highpFmt, outputs.out0[valueNdx]) : value32ToString(highpFmt, outputs.out0[valueNdx]))) << "\n"
							<< "\tExpected range: "
							<< intervalToString<typename Out::Out0>(highpFmt, reference0) << "\n";
				}
			}

			if (outCount > 1)
			{
				builder << "\t" << m_variables.out1->getName() << " = "
						<< (m_executor->isOutput64Bit(1u) ? value64ToString(highpFmt, outputs.out1[valueNdx]) : (m_executor->isOutput16Bit(1u) || m_caseCtx.isPackFloat16b ? value16ToString(highpFmt, outputs.out1[valueNdx]) : value32ToString(highpFmt, outputs.out1[valueNdx]))) << "\n"
						<< "\tExpected range: "
						<< intervalToString<typename Out::Out1>(highpFmt, reference1) << "\n";
			}

			builder << TestLog::EndMessage;
		}
	}

	// Report how many failure messages were suppressed by the maxMsgs limit.
	if (numErrors > maxMsgs)
	{
		testLog << TestLog::Message << "(Skipped " << (numErrors - maxMsgs) << " messages.)"
				<< TestLog::EndMessage;
	}

	if (numErrors == 0)
	{
		testLog << TestLog::Message << "All " << numValues << " inputs passed."
				<< TestLog::EndMessage;
	}
	else
	{
		testLog << TestLog::Message << numErrors << "/" << numValues << " inputs failed."
				<< TestLog::EndMessage;
	}

	if (numErrors)
		return tcu::TestStatus::fail(de::toString(numErrors) + string(" test failed. Check log for the details"));
	else
		return tcu::TestStatus::pass("Pass");

}
6392
6393 class PrecisionCase : public TestCase
6394 {
6395 protected:
PrecisionCase(const CaseContext & context,const string & name,const Interval & inputRange,const string & extension="")6396 PrecisionCase (const CaseContext& context, const string& name, const Interval& inputRange, const string& extension = "")
6397 : TestCase (context.testContext, name.c_str(), name.c_str())
6398 , m_ctx (context)
6399 , m_extension (extension)
6400 {
6401 m_ctx.inputRange = inputRange;
6402 m_spec.packFloat16Bit = context.isPackFloat16b;
6403 }
6404
initPrograms(vk::SourceCollections & programCollection) const6405 virtual void initPrograms (vk::SourceCollections& programCollection) const
6406 {
6407 generateSources(m_ctx.shaderType, m_spec, programCollection);
6408 }
6409
getFormat(void) const6410 const FloatFormat& getFormat (void) const { return m_ctx.floatFormat; }
6411
6412 template <typename In, typename Out>
6413 void testStatement (const Variables<In, Out>& variables, const Statement& stmt, SpirVCaseT spirvCase);
6414
6415 template<typename T>
makeSymbol(const Variable<T> & variable)6416 Symbol makeSymbol (const Variable<T>& variable)
6417 {
6418 return Symbol(variable.getName(), getVarTypeOf<T>(m_ctx.precision));
6419 }
6420
6421 CaseContext m_ctx;
6422 const string m_extension;
6423 ShaderSpec m_spec;
6424 };
6425
6426 template <typename In, typename Out>
testStatement(const Variables<In,Out> & variables,const Statement & stmt,SpirVCaseT spirvCase)6427 void PrecisionCase::testStatement (const Variables<In, Out>& variables, const Statement& stmt, SpirVCaseT spirvCase)
6428 {
6429 const int inCount = numInputs<In>();
6430 const int outCount = numOutputs<Out>();
6431 Environment env; // Hoisted out of the inner loop for optimization.
6432
6433 // Initialize ShaderSpec from precision, variables and statement.
6434 if (m_ctx.precision != glu::PRECISION_LAST)
6435 {
6436 ostringstream os;
6437 os << "precision " << glu::getPrecisionName(m_ctx.precision) << " float;\n";
6438 m_spec.globalDeclarations = os.str();
6439 }
6440
6441 if (!m_extension.empty())
6442 m_spec.globalDeclarations = "#extension " + m_extension + " : require\n";
6443
6444 m_spec.inputs.resize(inCount);
6445
6446 switch (inCount)
6447 {
6448 case 4:
6449 m_spec.inputs[3] = makeSymbol(*variables.in3);
6450 // Fallthrough
6451 case 3:
6452 m_spec.inputs[2] = makeSymbol(*variables.in2);
6453 // Fallthrough
6454 case 2:
6455 m_spec.inputs[1] = makeSymbol(*variables.in1);
6456 // Fallthrough
6457 case 1:
6458 m_spec.inputs[0] = makeSymbol(*variables.in0);
6459 // Fallthrough
6460 default:
6461 break;
6462 }
6463
6464 bool inputs16Bit = false;
6465 for (vector<Symbol>::const_iterator symIter = m_spec.inputs.begin(); symIter != m_spec.inputs.end(); ++symIter)
6466 inputs16Bit = inputs16Bit || glu::isDataTypeFloat16OrVec(symIter->varType.getBasicType());
6467
6468 if (inputs16Bit || m_spec.packFloat16Bit)
6469 m_spec.globalDeclarations += "#extension GL_EXT_shader_explicit_arithmetic_types: require\n";
6470
6471 m_spec.outputs.resize(outCount);
6472
6473 switch (outCount)
6474 {
6475 case 2:
6476 m_spec.outputs[1] = makeSymbol(*variables.out1);
6477 // Fallthrough
6478 case 1:
6479 m_spec.outputs[0] = makeSymbol(*variables.out0);
6480 // Fallthrough
6481 default:
6482 break;
6483 }
6484
6485 m_spec.source = de::toString(stmt);
6486 m_spec.spirvCase = spirvCase;
6487 }
6488
// Generic strict ordering functor for input values: plain operator<.
// Specializations below refine it for floats, vectors, matrices and tuples.
template <typename T>
struct InputLess
{
	bool operator() (const T& val1, const T& val2) const
	{
		const bool firstIsSmaller = (val1 < val2);
		return firstIsSmaller;
	}
};
6497
6498 template <typename T>
inputLess(const T & val1,const T & val2)6499 bool inputLess (const T& val1, const T& val2)
6500 {
6501 return InputLess<T>()(val1, val2);
6502 }
6503
6504 template <>
6505 struct InputLess<float>
6506 {
operator ()vkt::shaderexecutor::InputLess6507 bool operator() (const float& val1, const float& val2) const
6508 {
6509 if (deIsNaN(val1))
6510 return false;
6511 if (deIsNaN(val2))
6512 return true;
6513 return val1 < val2;
6514 }
6515 };
6516
6517 template <typename T, int Size>
6518 struct InputLess<Vector<T, Size> >
6519 {
operator ()vkt::shaderexecutor::InputLess6520 bool operator() (const Vector<T, Size>& vec1, const Vector<T, Size>& vec2) const
6521 {
6522 for (int ndx = 0; ndx < Size; ++ndx)
6523 {
6524 if (inputLess(vec1[ndx], vec2[ndx]))
6525 return true;
6526 if (inputLess(vec2[ndx], vec1[ndx]))
6527 return false;
6528 }
6529
6530 return false;
6531 }
6532 };
6533
6534 template <typename T, int Rows, int Cols>
6535 struct InputLess<Matrix<T, Rows, Cols> >
6536 {
operator ()vkt::shaderexecutor::InputLess6537 bool operator() (const Matrix<T, Rows, Cols>& mat1,
6538 const Matrix<T, Rows, Cols>& mat2) const
6539 {
6540 for (int col = 0; col < Cols; ++col)
6541 {
6542 if (inputLess(mat1[col], mat2[col]))
6543 return true;
6544 if (inputLess(mat2[col], mat1[col]))
6545 return false;
6546 }
6547
6548 return false;
6549 }
6550 };
6551
6552 template <typename In>
6553 struct InTuple :
6554 public Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6555 {
InTuplevkt::shaderexecutor::InTuple6556 InTuple (const typename In::In0& in0,
6557 const typename In::In1& in1,
6558 const typename In::In2& in2,
6559 const typename In::In3& in3)
6560 : Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6561 (in0, in1, in2, in3) {}
6562 };
6563
6564 template <typename In>
6565 struct InputLess<InTuple<In> >
6566 {
operator ()vkt::shaderexecutor::InputLess6567 bool operator() (const InTuple<In>& in1, const InTuple<In>& in2) const
6568 {
6569 if (inputLess(in1.a, in2.a))
6570 return true;
6571 if (inputLess(in2.a, in1.a))
6572 return false;
6573 if (inputLess(in1.b, in2.b))
6574 return true;
6575 if (inputLess(in2.b, in1.b))
6576 return false;
6577 if (inputLess(in1.c, in2.c))
6578 return true;
6579 if (inputLess(in2.c, in1.c))
6580 return false;
6581 if (inputLess(in1.d, in2.d))
6582 return true;
6583 return false;
6584 };
6585 };
6586
6587 template<typename In>
generateInputs(const Samplings<In> & samplings,const FloatFormat & floatFormat,Precision intPrecision,size_t numSamples,deUint32 seed,const Interval & inputRange)6588 Inputs<In> generateInputs (const Samplings<In>& samplings,
6589 const FloatFormat& floatFormat,
6590 Precision intPrecision,
6591 size_t numSamples,
6592 deUint32 seed,
6593 const Interval& inputRange)
6594 {
6595 Random rnd(seed);
6596 Inputs<In> ret;
6597 Inputs<In> fixedInputs;
6598 set<InTuple<In>, InputLess<InTuple<In> > > seenInputs;
6599
6600 samplings.in0.genFixeds(floatFormat, intPrecision, fixedInputs.in0, inputRange);
6601 samplings.in1.genFixeds(floatFormat, intPrecision, fixedInputs.in1, inputRange);
6602 samplings.in2.genFixeds(floatFormat, intPrecision, fixedInputs.in2, inputRange);
6603 samplings.in3.genFixeds(floatFormat, intPrecision, fixedInputs.in3, inputRange);
6604
6605 for (size_t ndx0 = 0; ndx0 < fixedInputs.in0.size(); ++ndx0)
6606 {
6607 for (size_t ndx1 = 0; ndx1 < fixedInputs.in1.size(); ++ndx1)
6608 {
6609 for (size_t ndx2 = 0; ndx2 < fixedInputs.in2.size(); ++ndx2)
6610 {
6611 for (size_t ndx3 = 0; ndx3 < fixedInputs.in3.size(); ++ndx3)
6612 {
6613 const InTuple<In> tuple (fixedInputs.in0[ndx0],
6614 fixedInputs.in1[ndx1],
6615 fixedInputs.in2[ndx2],
6616 fixedInputs.in3[ndx3]);
6617
6618 seenInputs.insert(tuple);
6619 ret.in0.push_back(tuple.a);
6620 ret.in1.push_back(tuple.b);
6621 ret.in2.push_back(tuple.c);
6622 ret.in3.push_back(tuple.d);
6623 }
6624 }
6625 }
6626 }
6627
6628 for (size_t ndx = 0; ndx < numSamples; ++ndx)
6629 {
6630 const typename In::In0 in0 = samplings.in0.genRandom(floatFormat, intPrecision, rnd, inputRange);
6631 const typename In::In1 in1 = samplings.in1.genRandom(floatFormat, intPrecision, rnd, inputRange);
6632 const typename In::In2 in2 = samplings.in2.genRandom(floatFormat, intPrecision, rnd, inputRange);
6633 const typename In::In3 in3 = samplings.in3.genRandom(floatFormat, intPrecision, rnd, inputRange);
6634 const InTuple<In> tuple (in0, in1, in2, in3);
6635
6636 if (de::contains(seenInputs, tuple))
6637 continue;
6638
6639 seenInputs.insert(tuple);
6640 ret.in0.push_back(in0);
6641 ret.in1.push_back(in1);
6642 ret.in2.push_back(in2);
6643 ret.in3.push_back(in3);
6644 }
6645
6646 return ret;
6647 }
6648
6649 class FuncCaseBase : public PrecisionCase
6650 {
6651 protected:
FuncCaseBase(const CaseContext & context,const string & name,const FuncBase & func)6652 FuncCaseBase (const CaseContext& context, const string& name, const FuncBase& func)
6653 : PrecisionCase (context, name, func.getInputRange(!context.isFloat64b && (context.precision == glu::PRECISION_LAST || context.isPackFloat16b)), func.getRequiredExtension())
6654 {
6655 }
6656
6657 StatementP m_stmt;
6658 };
6659
6660 template <typename Sig>
6661 class FuncCase : public FuncCaseBase
6662 {
6663 public:
6664 typedef Func<Sig> CaseFunc;
6665 typedef typename Sig::Ret Ret;
6666 typedef typename Sig::Arg0 Arg0;
6667 typedef typename Sig::Arg1 Arg1;
6668 typedef typename Sig::Arg2 Arg2;
6669 typedef typename Sig::Arg3 Arg3;
6670 typedef InTypes<Arg0, Arg1, Arg2, Arg3> In;
6671 typedef OutTypes<Ret> Out;
6672
FuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)6673 FuncCase (const CaseContext& context, const string& name, const CaseFunc& func, bool modularOp = false)
6674 : FuncCaseBase (context, name, func)
6675 , m_func (func)
6676 , m_modularOp (modularOp)
6677 {
6678 buildTest();
6679 }
6680
createInstance(Context & context) const6681 virtual TestInstance* createInstance (Context& context) const
6682 {
6683 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(), m_stmt, m_modularOp);
6684 }
6685
6686 protected:
6687 void buildTest (void);
getSamplings(void) const6688 virtual const Samplings<In>& getSamplings (void) const
6689 {
6690 return instance<DefaultSamplings<In> >();
6691 }
6692
6693 private:
6694 const CaseFunc& m_func;
6695 Variables<In, Out> m_variables;
6696 bool m_modularOp;
6697 };
6698
6699 template <typename Sig>
buildTest(void)6700 void FuncCase<Sig>::buildTest (void)
6701 {
6702 m_variables.out0 = variable<Ret>("out0");
6703 m_variables.out1 = variable<Void>("out1");
6704 m_variables.in0 = variable<Arg0>("in0");
6705 m_variables.in1 = variable<Arg1>("in1");
6706 m_variables.in2 = variable<Arg2>("in2");
6707 m_variables.in3 = variable<Arg3>("in3");
6708
6709 {
6710 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.in1, m_variables.in2, m_variables.in3);
6711 m_stmt = variableAssignment(m_variables.out0, expr);
6712
6713 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
6714 }
6715 }
6716
6717 template <typename Sig>
6718 class InOutFuncCase : public FuncCaseBase
6719 {
6720 public:
6721 typedef Func<Sig> CaseFunc;
6722 typedef typename Sig::Ret Ret;
6723 typedef typename Sig::Arg0 Arg0;
6724 typedef typename Sig::Arg1 Arg1;
6725 typedef typename Sig::Arg2 Arg2;
6726 typedef typename Sig::Arg3 Arg3;
6727 typedef InTypes<Arg0, Arg2, Arg3> In;
6728 typedef OutTypes<Ret, Arg1> Out;
6729
InOutFuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)6730 InOutFuncCase (const CaseContext& context, const string& name, const CaseFunc& func, bool modularOp = false)
6731 : FuncCaseBase (context, name, func)
6732 , m_func (func)
6733 , m_modularOp (modularOp)
6734 {
6735 buildTest();
6736 }
createInstance(Context & context) const6737 virtual TestInstance* createInstance (Context& context) const
6738 {
6739 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(), m_stmt, m_modularOp);
6740 }
6741
6742 protected:
6743 void buildTest (void);
getSamplings(void) const6744 virtual const Samplings<In>& getSamplings (void) const
6745 {
6746 return instance<DefaultSamplings<In> >();
6747 }
6748
6749 private:
6750 const CaseFunc& m_func;
6751 Variables<In, Out> m_variables;
6752 bool m_modularOp;
6753 };
6754
6755 template <typename Sig>
buildTest(void)6756 void InOutFuncCase<Sig>::buildTest (void)
6757 {
6758 m_variables.out0 = variable<Ret>("out0");
6759 m_variables.out1 = variable<Arg1>("out1");
6760 m_variables.in0 = variable<Arg0>("in0");
6761 m_variables.in1 = variable<Arg2>("in1");
6762 m_variables.in2 = variable<Arg3>("in2");
6763 m_variables.in3 = variable<Void>("in3");
6764
6765 {
6766 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.out1, m_variables.in1, m_variables.in2);
6767 m_stmt = variableAssignment(m_variables.out0, expr);
6768
6769 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
6770 }
6771 }
6772
6773 template <typename Sig>
createFuncCase(const CaseContext & context,const string & name,const Func<Sig> & func,bool modularOp=false)6774 PrecisionCase* createFuncCase (const CaseContext& context, const string& name, const Func<Sig>& func, bool modularOp = false)
6775 {
6776 switch (func.getOutParamIndex())
6777 {
6778 case -1:
6779 return new FuncCase<Sig>(context, name, func, modularOp);
6780 case 1:
6781 return new InOutFuncCase<Sig>(context, name, func, modularOp);
6782 default:
6783 DE_FATAL("Impossible");
6784 }
6785 return DE_NULL;
6786 }
6787
6788 class CaseFactory
6789 {
6790 public:
~CaseFactory(void)6791 virtual ~CaseFactory (void) {}
6792 virtual MovePtr<TestNode> createCase (const CaseContext& ctx) const = 0;
6793 virtual string getName (void) const = 0;
6794 virtual string getDesc (void) const = 0;
6795 };
6796
6797 class FuncCaseFactory : public CaseFactory
6798 {
6799 public:
6800 virtual const FuncBase& getFunc (void) const = 0;
getName(void) const6801 string getName (void) const { return de::toLower(getFunc().getName()); }
getDesc(void) const6802 string getDesc (void) const { return "Function '" + getFunc().getName() + "'"; }
6803 };
6804
6805 template <typename Sig>
6806 class GenFuncCaseFactory : public CaseFactory
6807 {
6808 public:
GenFuncCaseFactory(const GenFuncs<Sig> & funcs,const string & name,bool modularOp=false)6809 GenFuncCaseFactory (const GenFuncs<Sig>& funcs, const string& name, bool modularOp = false)
6810 : m_funcs (funcs)
6811 , m_name (de::toLower(name))
6812 , m_modularOp (modularOp)
6813 {
6814 }
6815
createCase(const CaseContext & ctx) const6816 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6817 {
6818 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6819
6820 group->addChild(createFuncCase(ctx, "scalar", m_funcs.func, m_modularOp));
6821 group->addChild(createFuncCase(ctx, "vec2", m_funcs.func2, m_modularOp));
6822 group->addChild(createFuncCase(ctx, "vec3", m_funcs.func3, m_modularOp));
6823 group->addChild(createFuncCase(ctx, "vec4", m_funcs.func4, m_modularOp));
6824 return MovePtr<TestNode>(group);
6825 }
6826
getName(void) const6827 string getName (void) const { return m_name; }
getDesc(void) const6828 string getDesc (void) const { return "Function '" + m_funcs.func.getName() + "'"; }
6829
6830 private:
6831 const GenFuncs<Sig> m_funcs;
6832 string m_name;
6833 bool m_modularOp;
6834 };
6835
6836 template <template <int, class> class GenF, typename T>
6837 class TemplateFuncCaseFactory : public FuncCaseFactory
6838 {
6839 public:
createCase(const CaseContext & ctx) const6840 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6841 {
6842 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6843
6844 group->addChild(createFuncCase(ctx, "scalar", instance<GenF<1, T> >()));
6845 group->addChild(createFuncCase(ctx, "vec2", instance<GenF<2, T> >()));
6846 group->addChild(createFuncCase(ctx, "vec3", instance<GenF<3, T> >()));
6847 group->addChild(createFuncCase(ctx, "vec4", instance<GenF<4, T> >()));
6848
6849 return MovePtr<TestNode>(group);
6850 }
6851
getFunc(void) const6852 const FuncBase& getFunc (void) const { return instance<GenF<1, T> >(); }
6853 };
6854
6855 template <template <int> class GenF>
6856 class SquareMatrixFuncCaseFactory : public FuncCaseFactory
6857 {
6858 public:
createCase(const CaseContext & ctx) const6859 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6860 {
6861 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6862
6863 group->addChild(createFuncCase(ctx, "mat2", instance<GenF<2> >()));
6864 #if 0
6865 // disabled until we get reasonable results
6866 group->addChild(createFuncCase(ctx, "mat3", instance<GenF<3> >()));
6867 group->addChild(createFuncCase(ctx, "mat4", instance<GenF<4> >()));
6868 #endif
6869
6870 return MovePtr<TestNode>(group);
6871 }
6872
getFunc(void) const6873 const FuncBase& getFunc (void) const { return instance<GenF<2> >(); }
6874 };
6875
6876 template <template <int, int, class> class GenF, typename T>
6877 class MatrixFuncCaseFactory : public FuncCaseFactory
6878 {
6879 public:
createCase(const CaseContext & ctx) const6880 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6881 {
6882 TestCaseGroup* const group = new TestCaseGroup(ctx.testContext, ctx.name.c_str(), ctx.name.c_str());
6883
6884 this->addCase<2, 2>(ctx, group);
6885 this->addCase<3, 2>(ctx, group);
6886 this->addCase<4, 2>(ctx, group);
6887 this->addCase<2, 3>(ctx, group);
6888 this->addCase<3, 3>(ctx, group);
6889 this->addCase<4, 3>(ctx, group);
6890 this->addCase<2, 4>(ctx, group);
6891 this->addCase<3, 4>(ctx, group);
6892 this->addCase<4, 4>(ctx, group);
6893
6894 return MovePtr<TestNode>(group);
6895 }
6896
getFunc(void) const6897 const FuncBase& getFunc (void) const { return instance<GenF<2,2, T> >(); }
6898
6899 private:
6900 template <int Rows, int Cols>
addCase(const CaseContext & ctx,TestCaseGroup * group) const6901 void addCase (const CaseContext& ctx, TestCaseGroup* group) const
6902 {
6903 const char* const name = dataTypeNameOf<Matrix<float, Rows, Cols> >();
6904 group->addChild(createFuncCase(ctx, name, instance<GenF<Rows, Cols, T> >()));
6905 }
6906 };
6907
6908 template <typename Sig>
6909 class SimpleFuncCaseFactory : public CaseFactory
6910 {
6911 public:
SimpleFuncCaseFactory(const Func<Sig> & func)6912 SimpleFuncCaseFactory (const Func<Sig>& func) : m_func(func) {}
6913
createCase(const CaseContext & ctx) const6914 MovePtr<TestNode> createCase (const CaseContext& ctx) const { return MovePtr<TestNode>(createFuncCase(ctx, ctx.name.c_str(), m_func)); }
getName(void) const6915 string getName (void) const { return de::toLower(m_func.getName()); }
getDesc(void) const6916 string getDesc (void) const { return "Function '" + getName() + "'"; }
6917
6918 private:
6919 const Func<Sig>& m_func;
6920 };
6921
6922 template <typename F>
createSimpleFuncCaseFactory(void)6923 SharedPtr<SimpleFuncCaseFactory<typename F::Sig> > createSimpleFuncCaseFactory (void)
6924 {
6925 return SharedPtr<SimpleFuncCaseFactory<typename F::Sig> >(new SimpleFuncCaseFactory<typename F::Sig>(instance<F>()));
6926 }
6927
6928 class CaseFactories
6929 {
6930 public:
~CaseFactories(void)6931 virtual ~CaseFactories (void) {}
6932 virtual const std::vector<const CaseFactory*> getFactories (void) const = 0;
6933 };
6934
6935 class BuiltinFuncs : public CaseFactories
6936 {
6937 public:
getFactories(void) const6938 const vector<const CaseFactory*> getFactories (void) const
6939 {
6940 vector<const CaseFactory*> ret;
6941
6942 for (size_t ndx = 0; ndx < m_factories.size(); ++ndx)
6943 ret.push_back(m_factories[ndx].get());
6944
6945 return ret;
6946 }
6947
addFactory(SharedPtr<const CaseFactory> fact)6948 void addFactory (SharedPtr<const CaseFactory> fact) { m_factories.push_back(fact); }
6949
6950 private:
6951 vector<SharedPtr<const CaseFactory> > m_factories;
6952 };
6953
6954 template <typename F>
addScalarFactory(BuiltinFuncs & funcs,string name="",bool modularOp=false)6955 void addScalarFactory (BuiltinFuncs& funcs, string name = "", bool modularOp = false)
6956 {
6957 if (name.empty())
6958 name = instance<F>().getName();
6959
6960 funcs.addFactory(SharedPtr<const CaseFactory>(new GenFuncCaseFactory<typename F::Sig>(makeVectorizedFuncs<F>(), name, modularOp)));
6961 }
6962
// Builds the collection of case factories for the float builtins.
// Factories are registered in the order of the statements below, which is
// also the order in which BuiltinFuncs::getFactories() returns them.
// Ownership of the collection is transferred to the caller.
MovePtr<const CaseFactories> createBuiltinCases ()
{
	MovePtr<BuiltinFuncs> funcs (new BuiltinFuncs());

	// Tests for ES3 builtins
	addScalarFactory<Comparison< Signature<int, float, float> > >(*funcs);
	addScalarFactory<Add< Signature<float, float, float> > >(*funcs);
	addScalarFactory<Sub< Signature<float, float, float> > >(*funcs);
	addScalarFactory<Mul< Signature<float, float, float> > >(*funcs);
	addScalarFactory<Div< Signature<float, float, float> > >(*funcs);

	addScalarFactory<Radians>(*funcs);
	addScalarFactory<Degrees>(*funcs);
	addScalarFactory<Sin<Signature<float, float> > >(*funcs);
	addScalarFactory<Cos<Signature<float, float> > >(*funcs);
	addScalarFactory<Tan>(*funcs);

	addScalarFactory<ASin>(*funcs);
	addScalarFactory<ACos>(*funcs);
	// Registered under the explicit name "atan2" instead of the function's own name.
	addScalarFactory<ATan2< Signature<float, float, float> > >(*funcs, "atan2");
	addScalarFactory<ATan<Signature<float, float> > >(*funcs);
	addScalarFactory<Sinh>(*funcs);
	addScalarFactory<Cosh>(*funcs);
	addScalarFactory<Tanh>(*funcs);
	addScalarFactory<ASinh>(*funcs);
	addScalarFactory<ACosh>(*funcs);
	addScalarFactory<ATanh>(*funcs);

	addScalarFactory<Pow>(*funcs);
	addScalarFactory<Exp<Signature<float, float> > >(*funcs);
	addScalarFactory<Log< Signature<float, float> > >(*funcs);
	addScalarFactory<Exp2<Signature<float, float> > >(*funcs);
	addScalarFactory<Log2< Signature<float, float> > >(*funcs);
	addScalarFactory<Sqrt32Bit>(*funcs);
	addScalarFactory<InverseSqrt< Signature<float, float> > >(*funcs);

	addScalarFactory<Abs< Signature<float, float> > >(*funcs);
	addScalarFactory<Sign< Signature<float, float> > >(*funcs);
	addScalarFactory<Floor32Bit>(*funcs);
	addScalarFactory<Trunc32Bit>(*funcs);
	addScalarFactory<Round< Signature<float, float> > >(*funcs);
	addScalarFactory<RoundEven< Signature<float, float> > >(*funcs);
	addScalarFactory<Ceil< Signature<float, float> > >(*funcs);
	addScalarFactory<Fract>(*funcs);

	// The only registration with modularOp = true; named "mod" explicitly.
	addScalarFactory<Mod32Bit>(*funcs, "mod", true);
	addScalarFactory<FRem32Bit>(*funcs);

	addScalarFactory<Modf32Bit>(*funcs);
	addScalarFactory<ModfStruct32Bit>(*funcs);
	addScalarFactory<Min< Signature<float, float, float> > >(*funcs);
	addScalarFactory<Max< Signature<float, float, float> > >(*funcs);
	addScalarFactory<Clamp< Signature<float, float, float, float> > >(*funcs);
	addScalarFactory<Mix>(*funcs);
	addScalarFactory<Step< Signature<float, float, float> > >(*funcs);
	addScalarFactory<SmoothStep< Signature<float, float, float, float> > >(*funcs);

	// Functions defined as templates over the vector size.
	funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, float>()));
	funcs->addFactory(createSimpleFuncCaseFactory<Cross>());
	funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, float>()));

	// Matrix functions.
	funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, float>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant>()));
	funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse>()));

	addScalarFactory<Frexp32Bit>(*funcs);
	addScalarFactory<FrexpStruct32Bit>(*funcs);
	addScalarFactory<LdExp <Signature<float, float, int> > >(*funcs);
	addScalarFactory<Fma <Signature<float, float, float, float> > >(*funcs);

	// Hand the collection over as a pointer-to-const.
	return MovePtr<const CaseFactories>(funcs.release());
}
7042
createBuiltinDoubleCases()7043 MovePtr<const CaseFactories> createBuiltinDoubleCases ()
7044 {
7045 MovePtr<BuiltinFuncs> funcs (new BuiltinFuncs());
7046
7047 // Tests for ES3 builtins
7048 addScalarFactory<Comparison<Signature<int, double, double>>>(*funcs);
7049 addScalarFactory<Add<Signature<double, double, double>>>(*funcs);
7050 addScalarFactory<Sub<Signature<double, double, double>>>(*funcs);
7051 addScalarFactory<Mul<Signature<double, double, double>>>(*funcs);
7052 addScalarFactory<Div<Signature<double, double, double>>>(*funcs);
7053
7054 // Radians, degrees, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, pow, exp, log, exp2 and log2
7055 // only work with 16-bit and 32-bit floating point types according to the spec.
7056 #if 0
7057 addScalarFactory<Radians64>(*funcs);
7058 addScalarFactory<Degrees64>(*funcs);
7059 addScalarFactory<Sin<Signature<double, double>>>(*funcs);
7060 addScalarFactory<Cos<Signature<double, double>>>(*funcs);
7061 addScalarFactory<Tan64Bit>(*funcs);
7062 addScalarFactory<ASin64Bit>(*funcs);
7063 addScalarFactory<ACos64Bit>(*funcs);
7064 addScalarFactory<ATan2<Signature<double, double, double>>>(*funcs, "atan2");
7065 addScalarFactory<ATan<Signature<double, double>>>(*funcs);
7066 addScalarFactory<Sinh64Bit>(*funcs);
7067 addScalarFactory<Cosh64Bit>(*funcs);
7068 addScalarFactory<Tanh64Bit>(*funcs);
7069 addScalarFactory<ASinh64Bit>(*funcs);
7070 addScalarFactory<ACosh64Bit>(*funcs);
7071 addScalarFactory<ATanh64Bit>(*funcs);
7072
7073 addScalarFactory<Pow64>(*funcs);
7074 addScalarFactory<Exp<Signature<double, double>>>(*funcs);
7075 addScalarFactory<Log<Signature<double, double>>>(*funcs);
7076 addScalarFactory<Exp2<Signature<double, double>>>(*funcs);
7077 addScalarFactory<Log2<Signature<double, double>>>(*funcs);
7078 #endif
7079 addScalarFactory<Sqrt64Bit>(*funcs);
7080 addScalarFactory<InverseSqrt<Signature<double, double>>>(*funcs);
7081
7082 addScalarFactory<Abs<Signature<double, double>>>(*funcs);
7083 addScalarFactory<Sign<Signature<double, double>>>(*funcs);
7084 addScalarFactory<Floor64Bit>(*funcs);
7085 addScalarFactory<Trunc64Bit>(*funcs);
7086 addScalarFactory<Round<Signature<double, double>>>(*funcs);
7087 addScalarFactory<RoundEven<Signature<double, double>>>(*funcs);
7088 addScalarFactory<Ceil<Signature<double, double>>>(*funcs);
7089 addScalarFactory<Fract64Bit>(*funcs);
7090
7091 addScalarFactory<Mod64Bit>(*funcs, "mod", true);
7092 addScalarFactory<FRem64Bit>(*funcs);
7093
7094 addScalarFactory<Modf64Bit>(*funcs);
7095 addScalarFactory<ModfStruct64Bit>(*funcs);
7096 addScalarFactory<Min<Signature<double, double, double>>>(*funcs);
7097 addScalarFactory<Max<Signature<double, double, double>>>(*funcs);
7098 addScalarFactory<Clamp<Signature<double, double, double, double>>>(*funcs);
7099 addScalarFactory<Mix64Bit>(*funcs);
7100 addScalarFactory<Step<Signature<double, double, double>>>(*funcs);
7101 addScalarFactory<SmoothStep<Signature<double, double, double, double>>>(*funcs);
7102
7103 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, double>()));
7104 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, double>()));
7105 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, double>()));
7106 funcs->addFactory(createSimpleFuncCaseFactory<Cross64Bit>());
7107 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, double>()));
7108 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, double>()));
7109 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, double>()));
7110 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, double>()));
7111
7112 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, double>()));
7113 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, double>()));
7114 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, double>()));
7115 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant64bit>()));
7116 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse64bit>()));
7117
7118 addScalarFactory<Frexp64Bit>(*funcs);
7119 addScalarFactory<FrexpStruct64Bit>(*funcs);
7120 addScalarFactory<LdExp<Signature<double, double, int>>>(*funcs);
7121 addScalarFactory<Fma<Signature<double, double, double, double>>>(*funcs);
7122
7123 return MovePtr<const CaseFactories>(funcs.release());
7124 }
7125
createBuiltinCases16Bit(void)7126 MovePtr<const CaseFactories> createBuiltinCases16Bit(void)
7127 {
7128 MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7129
7130 addScalarFactory<Comparison< Signature<int, deFloat16, deFloat16> > >(*funcs);
7131 addScalarFactory<Add< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7132 addScalarFactory<Sub< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7133 addScalarFactory<Mul< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7134 addScalarFactory<Div< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7135
7136 addScalarFactory<Radians16>(*funcs);
7137 addScalarFactory<Degrees16>(*funcs);
7138
7139 addScalarFactory<Sin<Signature<deFloat16, deFloat16> > >(*funcs);
7140 addScalarFactory<Cos<Signature<deFloat16, deFloat16> > >(*funcs);
7141 addScalarFactory<Tan16Bit>(*funcs);
7142 addScalarFactory<ASin16Bit>(*funcs);
7143 addScalarFactory<ACos16Bit>(*funcs);
7144 addScalarFactory<ATan2< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs, "atan2");
7145 addScalarFactory<ATan<Signature<deFloat16, deFloat16> > >(*funcs);
7146
7147 addScalarFactory<Sinh16Bit>(*funcs);
7148 addScalarFactory<Cosh16Bit>(*funcs);
7149 addScalarFactory<Tanh16Bit>(*funcs);
7150 addScalarFactory<ASinh16Bit>(*funcs);
7151 addScalarFactory<ACosh16Bit>(*funcs);
7152 addScalarFactory<ATanh16Bit>(*funcs);
7153
7154 addScalarFactory<Pow16>(*funcs);
7155 addScalarFactory<Exp< Signature<deFloat16, deFloat16> > >(*funcs);
7156 addScalarFactory<Log< Signature<deFloat16, deFloat16> > >(*funcs);
7157 addScalarFactory<Exp2< Signature<deFloat16, deFloat16> > >(*funcs);
7158 addScalarFactory<Log2< Signature<deFloat16, deFloat16> > >(*funcs);
7159 addScalarFactory<Sqrt16Bit>(*funcs);
7160 addScalarFactory<InverseSqrt16Bit>(*funcs);
7161
7162 addScalarFactory<Abs< Signature<deFloat16, deFloat16> > >(*funcs);
7163 addScalarFactory<Sign< Signature<deFloat16, deFloat16> > >(*funcs);
7164 addScalarFactory<Floor16Bit>(*funcs);
7165 addScalarFactory<Trunc16Bit>(*funcs);
7166 addScalarFactory<Round< Signature<deFloat16, deFloat16> > >(*funcs);
7167 addScalarFactory<RoundEven< Signature<deFloat16, deFloat16> > >(*funcs);
7168 addScalarFactory<Ceil< Signature<deFloat16, deFloat16> > >(*funcs);
7169 addScalarFactory<Fract16Bit>(*funcs);
7170
7171 addScalarFactory<Mod16Bit>(*funcs, "mod", true);
7172 addScalarFactory<FRem16Bit>(*funcs);
7173
7174 addScalarFactory<Modf16Bit>(*funcs);
7175 addScalarFactory<ModfStruct16Bit>(*funcs);
7176 addScalarFactory<Min< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7177 addScalarFactory<Max< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7178 addScalarFactory<Clamp< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7179 addScalarFactory<Mix16Bit>(*funcs);
7180 addScalarFactory<Step< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7181 addScalarFactory<SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7182
7183 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, deFloat16>()));
7184 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, deFloat16>()));
7185 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, deFloat16>()));
7186 funcs->addFactory(createSimpleFuncCaseFactory<Cross16Bit>());
7187 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, deFloat16>()));
7188 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, deFloat16>()));
7189 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, deFloat16>()));
7190 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, deFloat16>()));
7191
7192 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, deFloat16>()));
7193 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, deFloat16>()));
7194 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant16bit>()));
7195 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse16bit>()));
7196
7197 addScalarFactory<Frexp16Bit>(*funcs);
7198 addScalarFactory<FrexpStruct16Bit>(*funcs);
7199 addScalarFactory<LdExp <Signature<deFloat16, deFloat16, int> > >(*funcs);
7200 addScalarFactory<Fma <Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7201
7202 return MovePtr<const CaseFactories>(funcs.release());
7203 }
7204
createFuncGroup(TestContext & ctx,const CaseFactory & factory,int numRandoms)7205 TestCaseGroup* createFuncGroup (TestContext& ctx, const CaseFactory& factory, int numRandoms)
7206 {
7207 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str(), factory.getDesc().c_str());
7208 const FloatFormat highp (-126, 127, 23, true,
7209 tcu::MAYBE, // subnormals
7210 tcu::YES, // infinities
7211 tcu::MAYBE); // NaN
7212 const FloatFormat mediump (-14, 13, 10, false, tcu::MAYBE);
7213
7214 for (int precNdx = glu::PRECISION_MEDIUMP; precNdx < glu::PRECISION_LAST; ++precNdx)
7215 {
7216 const Precision precision = Precision(precNdx);
7217 const string precName (glu::getPrecisionName(precision));
7218 const FloatFormat& fmt = precNdx == glu::PRECISION_MEDIUMP ? mediump : highp;
7219
7220 const CaseContext caseCtx (precName, ctx, fmt, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms);
7221
7222 group->addChild(factory.createCase(caseCtx).release());
7223 }
7224
7225 return group;
7226 }
7227
createFuncGroupDouble(TestContext & ctx,const CaseFactory & factory,int numRandoms)7228 TestCaseGroup* createFuncGroupDouble (TestContext& ctx, const CaseFactory& factory, int numRandoms)
7229 {
7230 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str(), factory.getDesc().c_str());
7231 const Precision precision = Precision(glu::PRECISION_LAST);
7232 const FloatFormat highp (-1022, 1023, 52, true,
7233 tcu::MAYBE, // subnormals
7234 tcu::YES, // infinities
7235 tcu::MAYBE); // NaN
7236
7237 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT;
7238
7239 const CaseContext caseCtx("compute", ctx, highp, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms, precisionTestFeatures, false, true);
7240 group->addChild(factory.createCase(caseCtx).release());
7241
7242 return group;
7243 }
7244
createFuncGroup16Bit(TestContext & ctx,const CaseFactory & factory,int numRandoms,bool storage32)7245 TestCaseGroup* createFuncGroup16Bit(TestContext& ctx, const CaseFactory& factory, int numRandoms, bool storage32)
7246 {
7247 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str(), factory.getDesc().c_str());
7248 const Precision precision = Precision(glu::PRECISION_LAST);
7249 const FloatFormat float16 (-14, 15, 10, true, tcu::MAYBE);
7250
7251 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT;
7252 if (!storage32)
7253 precisionTestFeatures |= PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS;
7254
7255 const CaseContext caseCtx("compute", ctx, float16, float16, precision, glu::SHADERTYPE_COMPUTE, numRandoms, precisionTestFeatures, storage32);
7256 group->addChild(factory.createCase(caseCtx).release());
7257
7258 return group;
7259 }
7260
//! Fallback for the random-input count, used when the command line iteration count is not positive.
constexpr int defRandoms = 16384;
7262
addBuiltinPrecisionTests(TestContext & ctx,TestCaseGroup & dstGroup,const bool test16Bit=false,const bool storage32Bit=false)7263 void addBuiltinPrecisionTests (TestContext& ctx,
7264 TestCaseGroup& dstGroup,
7265 const bool test16Bit = false,
7266 const bool storage32Bit = false)
7267 {
7268 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7269 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7270
7271 MovePtr<const CaseFactories> cases = (test16Bit && !storage32Bit) ? createBuiltinCases16Bit()
7272 : createBuiltinCases();
7273 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7274 {
7275 if (!test16Bit)
7276 dstGroup.addChild(createFuncGroup(ctx, *cases->getFactories()[ndx], numRandoms));
7277 else
7278 dstGroup.addChild(createFuncGroup16Bit(ctx, *cases->getFactories()[ndx], numRandoms, storage32Bit));
7279 }
7280 }
7281
addBuiltinPrecisionDoubleTests(TestContext & ctx,TestCaseGroup & dstGroup)7282 void addBuiltinPrecisionDoubleTests (TestContext& ctx,
7283 TestCaseGroup& dstGroup)
7284 {
7285 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7286 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7287
7288 MovePtr<const CaseFactories> cases = createBuiltinDoubleCases();
7289 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7290 {
7291 dstGroup.addChild(createFuncGroupDouble(ctx, *cases->getFactories()[ndx], numRandoms));
7292 }
7293 }
7294
BuiltinPrecisionTests(tcu::TestContext & testCtx)7295 BuiltinPrecisionTests::BuiltinPrecisionTests (tcu::TestContext& testCtx)
7296 : tcu::TestCaseGroup(testCtx, "precision", "Builtin precision tests")
7297 {
7298 }
7299
~BuiltinPrecisionTests(void)7300 BuiltinPrecisionTests::~BuiltinPrecisionTests (void)
7301 {
7302 }
7303
init(void)7304 void BuiltinPrecisionTests::init (void)
7305 {
7306 addBuiltinPrecisionTests(m_testCtx, *this);
7307 }
7308
BuiltinPrecisionDoubleTests(tcu::TestContext & testCtx)7309 BuiltinPrecisionDoubleTests::BuiltinPrecisionDoubleTests (tcu::TestContext& testCtx)
7310 : tcu::TestCaseGroup(testCtx, "precision_double", "Builtin precision tests")
7311 {
7312 }
7313
~BuiltinPrecisionDoubleTests(void)7314 BuiltinPrecisionDoubleTests::~BuiltinPrecisionDoubleTests (void)
7315 {
7316 }
7317
init(void)7318 void BuiltinPrecisionDoubleTests::init (void)
7319 {
7320 addBuiltinPrecisionDoubleTests(m_testCtx, *this);
7321 }
7322
BuiltinPrecision16BitTests(tcu::TestContext & testCtx)7323 BuiltinPrecision16BitTests::BuiltinPrecision16BitTests (tcu::TestContext& testCtx)
7324 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage16b", "Builtin precision tests")
7325 {
7326 }
7327
~BuiltinPrecision16BitTests(void)7328 BuiltinPrecision16BitTests::~BuiltinPrecision16BitTests (void)
7329 {
7330 }
7331
init(void)7332 void BuiltinPrecision16BitTests::init (void)
7333 {
7334 addBuiltinPrecisionTests(m_testCtx, *this, true);
7335 }
7336
BuiltinPrecision16Storage32BitTests(tcu::TestContext & testCtx)7337 BuiltinPrecision16Storage32BitTests::BuiltinPrecision16Storage32BitTests(tcu::TestContext& testCtx)
7338 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage32b", "Builtin precision tests")
7339 {
7340 }
7341
~BuiltinPrecision16Storage32BitTests(void)7342 BuiltinPrecision16Storage32BitTests::~BuiltinPrecision16Storage32BitTests(void)
7343 {
7344 }
7345
init(void)7346 void BuiltinPrecision16Storage32BitTests::init(void)
7347 {
7348 addBuiltinPrecisionTests(m_testCtx, *this, true, true);
7349 }
7350
7351 } // shaderexecutor
7352 } // vkt
7353