1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2018 The Khronos Group Inc.
6 * Copyright (c) 2015 Samsung Electronics Co., Ltd.
7 * Copyright (c) 2016 The Android Open Source Project
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 * http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 *//*!
22 * \file
23 * \brief Precision and range tests for builtins and types.
24 *
25 *//*--------------------------------------------------------------------*/
26
27 #include "vktShaderBuiltinPrecisionTests.hpp"
28 #include "vktShaderExecutor.hpp"
29 #include "amber/vktAmberTestCase.hpp"
30
31 #include "deMath.h"
32 #include "deMemory.h"
33 #include "deFloat16.h"
34 #include "deDefs.hpp"
35 #include "deRandom.hpp"
36 #include "deSTLUtil.hpp"
37 #include "deStringUtil.hpp"
38 #include "deUniquePtr.hpp"
39 #include "deSharedPtr.hpp"
40 #include "deArrayUtil.hpp"
41
42 #include "tcuCommandLine.hpp"
43 #include "tcuFloatFormat.hpp"
44 #include "tcuInterval.hpp"
45 #include "tcuTestLog.hpp"
46 #include "tcuVector.hpp"
47 #include "tcuMatrix.hpp"
48 #include "tcuResultCollector.hpp"
49 #include "tcuMaybe.hpp"
50
51 #include "gluContextInfo.hpp"
52 #include "gluVarType.hpp"
53 #include "gluRenderContext.hpp"
54 #include "glwDefs.hpp"
55
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
63
// Uncomment this to get evaluation trace dumps to std::cerr
// #define GLS_ENABLE_TRACE

// Set this to true to dump even passing results.
#define GLS_LOG_ALL_RESULTS false

// Half-precision (binary16) bit patterns of a few frequently used constants.
#define FLOAT16_1_0 0x3C00 //1.0 float16bit
#define FLOAT16_2_0 0x4000 //2.0 float16bit
#define FLOAT16_3_0 0x4200 //3.0 float16bit
#define FLOAT16_0_5 0x3800 //0.5 float16bit
#define FLOAT16_0_0 0x0000 //0.0 float16bit
75
76
77 using tcu::Vector;
78 typedef Vector<deFloat16, 1> Vec1_16Bit;
79 typedef Vector<deFloat16, 2> Vec2_16Bit;
80 typedef Vector<deFloat16, 3> Vec3_16Bit;
81 typedef Vector<deFloat16, 4> Vec4_16Bit;
82
83 typedef Vector<double, 1> Vec1_64Bit;
84 typedef Vector<double, 2> Vec2_64Bit;
85 typedef Vector<double, 3> Vec3_64Bit;
86 typedef Vector<double, 4> Vec4_64Bit;
87
enum
{
	// Computing reference intervals can take a non-trivial amount of time, especially on
	// platforms where toggling the floating-point rounding mode is slow (e.g. emulated ARM on x86).
	// As a workaround, the watchdog is kept happy by touching it periodically during reference
	// interval computation.
	// NOTE(review): presumably the number of values processed between watchdog touches;
	// the use site is not visible in this chunk.
	TOUCH_WATCHDOG_VALUE_FREQUENCY = 512
};
96
97 namespace vkt
98 {
99 namespace shaderexecutor
100 {
101
102 using std::string;
103 using std::map;
104 using std::ostream;
105 using std::ostringstream;
106 using std::pair;
107 using std::vector;
108 using std::set;
109
110 using de::MovePtr;
111 using de::Random;
112 using de::SharedPtr;
113 using de::UniquePtr;
114 using tcu::Interval;
115 using tcu::FloatFormat;
116 using tcu::MessageBuilder;
117 using tcu::TestLog;
118 using tcu::Vector;
119 using tcu::Matrix;
120 using glu::Precision;
121 using glu::VarType;
122 using glu::DataType;
123 using glu::ShaderType;
124
125 enum PrecisionTestFeatureBits
126 {
127 PRECISION_TEST_FEATURES_NONE = 0u,
128 PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS = (1u << 1),
129 PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS = (1u << 2),
130 PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT = (1u << 3),
131 PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT = (1u << 4),
132 PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT = (1u << 5),
133 PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT = (1u << 6),
134 };
135 typedef deUint32 PrecisionTestFeatures;
136
137
areFeaturesSupported(const Context & context,deUint32 toCheck)138 void areFeaturesSupported (const Context& context, deUint32 toCheck)
139 {
140 if (toCheck == PRECISION_TEST_FEATURES_NONE) return;
141
142 const vk::VkPhysicalDevice16BitStorageFeatures& extensionFeatures = context.get16BitStorageFeatures();
143
144 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS) != 0 && extensionFeatures.storageBuffer16BitAccess == VK_FALSE)
145 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
146
147 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS) != 0 && extensionFeatures.uniformAndStorageBuffer16BitAccess == VK_FALSE)
148 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
149
150 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT) != 0 && extensionFeatures.storagePushConstant16 == VK_FALSE)
151 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
152
153 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT) != 0 && extensionFeatures.storageInputOutput16 == VK_FALSE)
154 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
155
156 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT) != 0 && context.getShaderFloat16Int8Features().shaderFloat16 == VK_FALSE)
157 TCU_THROW(NotSupportedError, "Requested 16-bit floats (halfs) are not supported in shader code");
158
159 if ((toCheck & PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT) != 0 && context.getDeviceFeatures().shaderFloat64 == VK_FALSE)
160 TCU_THROW(NotSupportedError, "Requested 64-bit floats are not supported in shader code");
161 }
162
/*--------------------------------------------------------------------*//*!
 * \brief Generic singleton accessor.
 *
 * instance<T>() hands out a reference to one shared, default-constructed
 * object of type T. The GLSL function implementations rely on this: every
 * function is an object of its own distinct class, and keeping a context
 * object with a member for each of those classes would be impractical, so
 * global singletons are used instead.
 *
 *//*--------------------------------------------------------------------*/
template <typename T>
const T& instance (void)
{
	// Function-local static: built on first call, the same object afterwards.
	static const T s_singleton = T();
	return s_singleton;
}
180
181 /*--------------------------------------------------------------------*//*!
182 * \brief Empty placeholder type for unused template parameters.
183 *
184 * In the precision tests we are dealing with functions of different arities.
185 * To minimize code duplication, we only define templates with the maximum
186 * number of arguments, currently four. If a function's arity is less than the
187 * maximum, Void us used as the type for unused arguments.
188 *
189 * Although Voids are not used at run-time, they still must be compilable, so
190 * they must support all operations that other types do.
191 *
192 *//*--------------------------------------------------------------------*/
193 struct Void
194 {
195 typedef Void Element;
196 enum
197 {
198 SIZE = 0,
199 };
200
201 template <typename T>
Voidvkt::shaderexecutor::Void202 explicit Void (const T&) {}
Voidvkt::shaderexecutor::Void203 Void (void) {}
operator doublevkt::shaderexecutor::Void204 operator double (void) const { return TCU_NAN; }
205
206 // These are used to make Voids usable as containers in container-generic code.
operator []vkt::shaderexecutor::Void207 Void& operator[] (int) { return *this; }
operator []vkt::shaderexecutor::Void208 const Void& operator[] (int) const { return *this; }
209 };
210
operator <<(ostream & os,Void)211 ostream& operator<< (ostream& os, Void) { return os << "()"; }
212
213 //! Returns true for all other types except Void
isTypeValid(void)214 template <typename T> bool isTypeValid (void) { return true; }
isTypeValid(void)215 template <> bool isTypeValid<Void> (void) { return false; }
216
isInteger(void)217 template <typename T> bool isInteger (void) { return false; }
isInteger(void)218 template <> bool isInteger<int> (void) { return true; }
isInteger(void)219 template <> bool isInteger<tcu::IVec2> (void) { return true; }
isInteger(void)220 template <> bool isInteger<tcu::IVec3> (void) { return true; }
isInteger(void)221 template <> bool isInteger<tcu::IVec4> (void) { return true; }
222
223 //! Utility function for getting the name of a data type.
224 //! This is used in vector and matrix constructors.
225 template <typename T>
dataTypeNameOf(void)226 const char* dataTypeNameOf (void)
227 {
228 return glu::getDataTypeName(glu::dataTypeOf<T>());
229 }
230
231 template <>
dataTypeNameOf(void)232 const char* dataTypeNameOf<Void> (void)
233 {
234 DE_FATAL("Impossible");
235 return DE_NULL;
236 }
237
238 template <typename T>
getVarTypeOf(Precision prec=glu::PRECISION_LAST)239 VarType getVarTypeOf (Precision prec = glu::PRECISION_LAST)
240 {
241 return glu::varTypeOf<T>(prec);
242 }
243
244 //! A hack to get Void support for VarType.
245 template <>
getVarTypeOf(Precision)246 VarType getVarTypeOf<Void> (Precision)
247 {
248 DE_FATAL("Impossible");
249 return VarType();
250 }
251
/*--------------------------------------------------------------------*//*!
 * \brief Type traits for generalized interval types.
 *
 * Acceptable results are computed not only for float-valued expressions but
 * also for compound values: vectors and matrices. A set of vectors is
 * approximated by a vector of intervals, and likewise for matrices.
 *
 * Each value type therefore needs a family of operations on itself and on
 * its interval approximation; Traits<T> provides them.
 *
 * Traits<T>::IVal is the approximation of T: `Interval` for scalar types,
 * and a vector or matrix of intervals for container types.
 *
 * Free-function wrappers around the Traits<T> operations let template
 * argument deduction choose T, so one can write:
 *
 * makeIVal(someFloat)
 *
 * instead of:
 *
 * Traits<float>::doMakeIVal(someFloat)
 *
 *//*--------------------------------------------------------------------*/

template <class T> struct Traits;
278
279 //! Create container from elementwise singleton values.
280 template <typename T>
makeIVal(const T & value)281 typename Traits<T>::IVal makeIVal (const T& value)
282 {
283 return Traits<T>::doMakeIVal(value);
284 }
285
286 //! Elementwise union of intervals.
287 template <typename T>
unionIVal(const typename Traits<T>::IVal & a,const typename Traits<T>::IVal & b)288 typename Traits<T>::IVal unionIVal (const typename Traits<T>::IVal& a,
289 const typename Traits<T>::IVal& b)
290 {
291 return Traits<T>::doUnion(a, b);
292 }
293
294 //! Returns true iff every element of `ival` contains the corresponding element of `value`.
295 template <typename T, typename U = Void>
contains(const typename Traits<T>::IVal & ival,const T & value,bool is16Bit=false,const tcu::Maybe<U> & modularDivisor=tcu::Nothing)296 bool contains (const typename Traits<T>::IVal& ival, const T& value, bool is16Bit = false, const tcu::Maybe<U>& modularDivisor = tcu::Nothing)
297 {
298 return Traits<T>::doContains(ival, value, is16Bit, modularDivisor);
299 }
300
301 //! Print out an interval with the precision of `fmt`.
302 template <typename T>
printIVal(const FloatFormat & fmt,const typename Traits<T>::IVal & ival,ostream & os)303 void printIVal (const FloatFormat& fmt, const typename Traits<T>::IVal& ival, ostream& os)
304 {
305 Traits<T>::doPrintIVal(fmt, ival, os);
306 }
307
308 template <typename T>
intervalToString(const FloatFormat & fmt,const typename Traits<T>::IVal & ival)309 string intervalToString (const FloatFormat& fmt, const typename Traits<T>::IVal& ival)
310 {
311 ostringstream oss;
312 printIVal<T>(fmt, ival, oss);
313 return oss.str();
314 }
315
316 //! Print out a value with the precision of `fmt`.
317 template <typename T>
printValue16(const FloatFormat & fmt,const T & value,ostream & os)318 void printValue16 (const FloatFormat& fmt, const T& value, ostream& os)
319 {
320 Traits<T>::doPrintValue16(fmt, value, os);
321 }
322
323 template <typename T>
value16ToString(const FloatFormat & fmt,const T & val)324 string value16ToString(const FloatFormat& fmt, const T& val)
325 {
326 ostringstream oss;
327 printValue16(fmt, val, oss);
328 return oss.str();
329 }
330
getComparisonOperation(const int ndx)331 const std::string getComparisonOperation(const int ndx)
332 {
333 const int operationCount = 10;
334 DE_ASSERT(de::inBounds(ndx, 0, operationCount));
335 const std::string operations[operationCount] =
336 {
337 "OpFOrdEqual\t\t\t",
338 "OpFOrdGreaterThan\t",
339 "OpFOrdLessThan\t\t",
340 "OpFOrdGreaterThanEqual",
341 "OpFOrdLessThanEqual\t",
342 "OpFUnordEqual\t\t",
343 "OpFUnordGreaterThan\t",
344 "OpFUnordLessThan\t",
345 "OpFUnordGreaterThanEqual",
346 "OpFUnordLessThanEqual"
347 };
348 return operations[ndx];
349 }
350
351 template <typename T>
comparisonMessage(const T & val)352 string comparisonMessage(const T& val)
353 {
354 DE_UNREF(val);
355 return "";
356 }
357
358 template <>
comparisonMessage(const int & val)359 string comparisonMessage(const int& val)
360 {
361 ostringstream oss;
362
363 int flags = val;
364 for(int ndx = 0; ndx < 10; ++ndx)
365 {
366 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags & 1) == 1 ? "TRUE" : "FALSE") << "\n";
367 flags = flags >> 1;
368 }
369 return oss.str();
370 }
371
372 template <>
comparisonMessage(const tcu::IVec2 & val)373 string comparisonMessage(const tcu::IVec2& val)
374 {
375 ostringstream oss;
376 tcu::IVec2 flags = val;
377 for (int ndx = 0; ndx < 10; ++ndx)
378 {
379 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t" << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
380 flags.x() = flags.x() >> 1;
381 flags.y() = flags.y() >> 1;
382 }
383 return oss.str();
384 }
385
386 template <>
comparisonMessage(const tcu::IVec3 & val)387 string comparisonMessage(const tcu::IVec3& val)
388 {
389 ostringstream oss;
390 tcu::IVec3 flags = val;
391 for (int ndx = 0; ndx < 10; ++ndx)
392 {
393 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
394 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
395 << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
396 flags.x() = flags.x() >> 1;
397 flags.y() = flags.y() >> 1;
398 flags.z() = flags.z() >> 1;
399 }
400 return oss.str();
401 }
402
403 template <>
comparisonMessage(const tcu::IVec4 & val)404 string comparisonMessage(const tcu::IVec4& val)
405 {
406 ostringstream oss;
407 tcu::IVec4 flags = val;
408 for (int ndx = 0; ndx < 10; ++ndx)
409 {
410 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
411 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
412 << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
413 << ((flags.w() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
414 flags.x() = flags.x() >> 1;
415 flags.y() = flags.y() >> 1;
416 flags.z() = flags.z() >> 1;
417 flags.w() = flags.z() >> 1;
418 }
419 return oss.str();
420 }
421 //! Print out a value with the precision of `fmt`.
422 template <typename T>
printValue32(const FloatFormat & fmt,const T & value,ostream & os)423 void printValue32 (const FloatFormat& fmt, const T& value, ostream& os)
424 {
425 Traits<T>::doPrintValue32(fmt, value, os);
426 }
427
428 template <typename T>
value32ToString(const FloatFormat & fmt,const T & val)429 string value32ToString (const FloatFormat& fmt, const T& val)
430 {
431 ostringstream oss;
432 printValue32(fmt, val, oss);
433 return oss.str();
434 }
435
436 template <typename T>
printValue64(const FloatFormat & fmt,const T & value,ostream & os)437 void printValue64 (const FloatFormat& fmt, const T& value, ostream& os)
438 {
439 Traits<T>::doPrintValue64(fmt, value, os);
440 }
441
442 template <typename T>
value64ToString(const FloatFormat & fmt,const T & val)443 string value64ToString (const FloatFormat& fmt, const T& val)
444 {
445 ostringstream oss;
446 printValue64(fmt, val, oss);
447 return oss.str();
448 }
449
450 //! Approximate `value` elementwise to the float precision defined in `fmt`.
451 //! The resulting interval might not be a singleton if rounding in both
452 //! directions is allowed.
453 template <typename T>
round(const FloatFormat & fmt,const T & value)454 typename Traits<T>::IVal round (const FloatFormat& fmt, const T& value)
455 {
456 return Traits<T>::doRound(fmt, value);
457 }
458
459 template <typename T>
convert(const FloatFormat & fmt,const typename Traits<T>::IVal & value)460 typename Traits<T>::IVal convert (const FloatFormat& fmt,
461 const typename Traits<T>::IVal& value)
462 {
463 return Traits<T>::doConvert(fmt, value);
464 }
465
466 // Matching input and output types. We may be in a modulo case and modularDivisor may have an actual value.
467 template <typename T>
intervalContains(const Interval & interval,T value,const tcu::Maybe<T> & modularDivisor)468 bool intervalContains (const Interval& interval, T value, const tcu::Maybe<T>& modularDivisor)
469 {
470 bool contained = interval.contains(value);
471
472 if (!contained && modularDivisor)
473 {
474 const T divisor = modularDivisor.get();
475
476 // In a modulo operation, if the calculated answer contains the divisor, allow exactly 0.0 as a replacement. Alternatively,
477 // if the calculated answer contains 0.0, allow exactly the divisor as a replacement.
478 if (interval.contains(static_cast<double>(divisor)))
479 contained |= (value == 0.0);
480 if (interval.contains(0.0))
481 contained |= (value == divisor);
482 }
483 return contained;
484 }
485
486 // When the input and output types do not match, we are not in a real modulo operation. Do not take the divisor into account. This
487 // version is provided for syntactical compatibility only.
488 template <typename T, typename U>
intervalContains(const Interval & interval,T value,const tcu::Maybe<U> & modularDivisor)489 bool intervalContains (const Interval& interval, T value, const tcu::Maybe<U>& modularDivisor)
490 {
491 DE_UNREF(modularDivisor); // For release builds.
492 DE_ASSERT(!modularDivisor);
493 return interval.contains(value);
494 }
495
496 //! Common traits for scalar types.
497 template <typename T>
498 struct ScalarTraits
499 {
500 typedef Interval IVal;
501
doMakeIValvkt::shaderexecutor::ScalarTraits502 static Interval doMakeIVal (const T& value)
503 {
504 // Thankfully all scalar types have a well-defined conversion to `double`,
505 // hence Interval can represent their ranges without problems.
506 return Interval(double(value));
507 }
508
doUnionvkt::shaderexecutor::ScalarTraits509 static Interval doUnion (const Interval& a, const Interval& b)
510 {
511 return a | b;
512 }
513
doContainsvkt::shaderexecutor::ScalarTraits514 static bool doContains (const Interval& a, T value)
515 {
516 return a.contains(double(value));
517 }
518
doConvertvkt::shaderexecutor::ScalarTraits519 static Interval doConvert (const FloatFormat& fmt, const IVal& ival)
520 {
521 return fmt.convert(ival);
522 }
523
doConvertvkt::shaderexecutor::ScalarTraits524 static Interval doConvert (const FloatFormat& fmt, const IVal& ival, bool is16Bit)
525 {
526 DE_UNREF(is16Bit);
527 return fmt.convert(ival);
528 }
529
doRoundvkt::shaderexecutor::ScalarTraits530 static Interval doRound (const FloatFormat& fmt, T value)
531 {
532 return fmt.roundOut(double(value), false);
533 }
534 };
535
536 template <>
537 struct ScalarTraits<deUint16>
538 {
539 typedef Interval IVal;
540
doMakeIValvkt::shaderexecutor::ScalarTraits541 static Interval doMakeIVal (const deUint16& value)
542 {
543 // Thankfully all scalar types have a well-defined conversion to `double`,
544 // hence Interval can represent their ranges without problems.
545 return Interval(double(deFloat16To32(value)));
546 }
547
doUnionvkt::shaderexecutor::ScalarTraits548 static Interval doUnion (const Interval& a, const Interval& b)
549 {
550 return a | b;
551 }
552
doConvertvkt::shaderexecutor::ScalarTraits553 static Interval doConvert (const FloatFormat& fmt, const IVal& ival)
554 {
555 return fmt.convert(ival);
556 }
557
doRoundvkt::shaderexecutor::ScalarTraits558 static Interval doRound (const FloatFormat& fmt, deUint16 value)
559 {
560 return fmt.roundOut(double(deFloat16To32(value)), false);
561 }
562 };
563
564 template<>
565 struct Traits<float> : ScalarTraits<float>
566 {
doPrintIValvkt::shaderexecutor::Traits567 static void doPrintIVal (const FloatFormat& fmt,
568 const Interval& ival,
569 ostream& os)
570 {
571 os << fmt.intervalToHex(ival);
572 }
573
doPrintValue16vkt::shaderexecutor::Traits574 static void doPrintValue16 (const FloatFormat& fmt,
575 const float& value,
576 ostream& os)
577 {
578 const deUint32 iRep = reinterpret_cast<const deUint32 & >(value);
579 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
580 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
581 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
582 }
583
doPrintValue32vkt::shaderexecutor::Traits584 static void doPrintValue32 (const FloatFormat& fmt,
585 const float& value,
586 ostream& os)
587 {
588 os << fmt.floatToHex(value);
589 }
590
doPrintValue64vkt::shaderexecutor::Traits591 static void doPrintValue64 (const FloatFormat& fmt,
592 const float& value,
593 ostream& os)
594 {
595 os << fmt.floatToHex(value);
596 }
597
598 template <typename U>
doContainsvkt::shaderexecutor::Traits599 static bool doContains (const Interval& a, const float& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
600 {
601 if(is16Bit)
602 {
603 // Note: for deFloat16s packed in 32 bits, the original divisor is provided as a float to the shader in the input
604 // buffer, so U is also float here and we call the right interlvalContains() version.
605 const deUint32 iRep = reinterpret_cast<const deUint32&>(value);
606 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
607 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
608 return intervalContains(a, res0, modularDivisor) && (res1 == -1.0);
609 }
610 return intervalContains(a, value, modularDivisor);
611 }
612 };
613
614 template<>
615 struct Traits<double> : ScalarTraits<double>
616 {
doPrintIValvkt::shaderexecutor::Traits617 static void doPrintIVal (const FloatFormat& fmt,
618 const Interval& ival,
619 ostream& os)
620 {
621 os << fmt.intervalToHex(ival);
622 }
623
doPrintValue16vkt::shaderexecutor::Traits624 static void doPrintValue16 (const FloatFormat& fmt,
625 const double& value,
626 ostream& os)
627 {
628 const deUint64 iRep = reinterpret_cast<const deUint64&>(value);
629 double byte0 = deFloat16To64((deFloat16)((iRep ) & 0xffff));
630 double byte1 = deFloat16To64((deFloat16)((iRep >> 16) & 0xffff));
631 double byte2 = deFloat16To64((deFloat16)((iRep >> 32) & 0xffff));
632 double byte3 = deFloat16To64((deFloat16)((iRep >> 48) & 0xffff));
633 os << fmt.floatToHex(byte0) << " " << fmt.floatToHex(byte1) << " " << fmt.floatToHex(byte2) << " " << fmt.floatToHex(byte3);
634 }
635
doPrintValue32vkt::shaderexecutor::Traits636 static void doPrintValue32 (const FloatFormat& fmt,
637 const double& value,
638 ostream& os)
639 {
640 const deUint64 iRep = reinterpret_cast<const deUint64&>(value);
641 double res0 = static_cast<double>((float)((iRep ) & 0xffffffff));
642 double res1 = static_cast<double>((float)((iRep >> 32) & 0xffffffff));
643 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
644 }
645
doPrintValue64vkt::shaderexecutor::Traits646 static void doPrintValue64 (const FloatFormat& fmt,
647 const double& value,
648 ostream& os)
649 {
650 os << fmt.floatToHex(value);
651 }
652
653 template <class U>
doContainsvkt::shaderexecutor::Traits654 static bool doContains (const Interval& a, const double& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
655 {
656 DE_UNREF(is16Bit);
657 DE_ASSERT(!is16Bit);
658 return intervalContains(a, value, modularDivisor);
659 }
660 };
661
662 template<>
663 struct Traits<deFloat16> : ScalarTraits<deFloat16>
664 {
doPrintIValvkt::shaderexecutor::Traits665 static void doPrintIVal (const FloatFormat& fmt,
666 const Interval& ival,
667 ostream& os)
668 {
669 os << fmt.intervalToHex(ival);
670 }
671
doPrintValue16vkt::shaderexecutor::Traits672 static void doPrintValue16 (const FloatFormat& fmt,
673 const deFloat16& value,
674 ostream& os)
675 {
676 const float res0 = deFloat16To32(value);
677 os << fmt.floatToHex(static_cast<double>(res0));
678 }
doPrintValue32vkt::shaderexecutor::Traits679 static void doPrintValue32 (const FloatFormat& fmt,
680 const deFloat16& value,
681 ostream& os)
682 {
683 const float res0 = deFloat16To32(value);
684 os << fmt.floatToHex(static_cast<double>(res0));
685 }
686
doPrintValue64vkt::shaderexecutor::Traits687 static void doPrintValue64 (const FloatFormat& fmt,
688 const deFloat16& value,
689 ostream& os)
690 {
691 const double res0 = deFloat16To64(value);
692 os << fmt.floatToHex(res0);
693 }
694
695 // When the value and divisor are both deFloat16, convert both to float to call the right intervalContains version.
doContainsvkt::shaderexecutor::Traits696 static bool doContains (const Interval& a, const deFloat16& value, bool is16Bit, const tcu::Maybe<deFloat16>& modularDivisor)
697 {
698 DE_UNREF(is16Bit);
699 float res0 = deFloat16To32(value);
700 const tcu::Maybe<float> convertedDivisor = (modularDivisor ? tcu::just(deFloat16To32(modularDivisor.get())) : tcu::Nothing);
701 return intervalContains(a, res0, convertedDivisor);
702 }
703
704 // If the types don't match we should not be in a modulo operation, so no conversion should take place.
705 template <class U>
doContainsvkt::shaderexecutor::Traits706 static bool doContains (const Interval& a, const deFloat16& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
707 {
708 DE_UNREF(is16Bit);
709 float res0 = deFloat16To32(value);
710 return intervalContains(a, res0, modularDivisor);
711 }
712 };
713
714 template<>
715 struct Traits<bool> : ScalarTraits<bool>
716 {
doPrintValue16vkt::shaderexecutor::Traits717 static void doPrintValue16 (const FloatFormat&,
718 const float& value,
719 ostream& os)
720 {
721 os << (value != 0.0f ? "true" : "false");
722 }
723
doPrintValue32vkt::shaderexecutor::Traits724 static void doPrintValue32 (const FloatFormat&,
725 const float& value,
726 ostream& os)
727 {
728 os << (value != 0.0f ? "true" : "false");
729 }
730
doPrintValue64vkt::shaderexecutor::Traits731 static void doPrintValue64 (const FloatFormat&,
732 const float& value,
733 ostream& os)
734 {
735 os << (value != 0.0f ? "true" : "false");
736 }
737
doPrintIValvkt::shaderexecutor::Traits738 static void doPrintIVal (const FloatFormat&,
739 const Interval& ival,
740 ostream& os)
741 {
742 os << "{";
743 if (ival.contains(false))
744 os << "false";
745 if (ival.contains(false) && ival.contains(true))
746 os << ", ";
747 if (ival.contains(true))
748 os << "true";
749 os << "}";
750 }
751 };
752
753 template<>
754 struct Traits<int> : ScalarTraits<int>
755 {
doPrintValue16vkt::shaderexecutor::Traits756 static void doPrintValue16 (const FloatFormat&,
757 const int& value,
758 ostream& os)
759 {
760 int res0 = value & 0xFFFF;
761 int res1 = value >> 16;
762 os << res0 << " " << res1;
763 }
764
doPrintValue32vkt::shaderexecutor::Traits765 static void doPrintValue32 (const FloatFormat&,
766 const int& value,
767 ostream& os)
768 {
769 os << value;
770 }
771
doPrintValue64vkt::shaderexecutor::Traits772 static void doPrintValue64 (const FloatFormat&,
773 const int& value,
774 ostream& os)
775 {
776 os << value;
777 }
778
doPrintIValvkt::shaderexecutor::Traits779 static void doPrintIVal (const FloatFormat&,
780 const Interval& ival,
781 ostream& os)
782 {
783 os << "[" << int(ival.lo()) << ", " << int(ival.hi()) << "]";
784 }
785
786 template <typename U>
doContainsvkt::shaderexecutor::Traits787 static bool doContains (const Interval& a, const int& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
788 {
789 DE_UNREF(is16Bit);
790 return intervalContains(a, value, modularDivisor);
791 }
792 };
793
794 //! Common traits for containers, i.e. vectors and matrices.
795 //! T is the container type itself, I is the same type with interval elements.
796 template <typename T, typename I>
797 struct ContainerTraits
798 {
799 typedef typename T::Element Element;
800 typedef I IVal;
801
doMakeIValvkt::shaderexecutor::ContainerTraits802 static IVal doMakeIVal (const T& value)
803 {
804 IVal ret;
805
806 for (int ndx = 0; ndx < T::SIZE; ++ndx)
807 ret[ndx] = makeIVal(value[ndx]);
808
809 return ret;
810 }
811
doUnionvkt::shaderexecutor::ContainerTraits812 static IVal doUnion (const IVal& a, const IVal& b)
813 {
814 IVal ret;
815
816 for (int ndx = 0; ndx < T::SIZE; ++ndx)
817 ret[ndx] = unionIVal<Element>(a[ndx], b[ndx]);
818
819 return ret;
820 }
821
822 // When the input and output types match, we may be in a modulo operation. If the divisor is provided, use each of its
823 // components to determine if the obtained result is fine.
doContainsvkt::shaderexecutor::ContainerTraits824 static bool doContains (const IVal& ival, const T& value, bool is16Bit, const tcu::Maybe<T>& modularDivisor)
825 {
826 using DivisorElement = typename T::Element;
827
828 for (int ndx = 0; ndx < T::SIZE; ++ndx)
829 {
830 const tcu::Maybe<DivisorElement> divisorElement = (modularDivisor ? tcu::just((*modularDivisor)[ndx]) : tcu::Nothing);
831 if (!contains(ival[ndx], value[ndx], is16Bit, divisorElement))
832 return false;
833 }
834
835 return true;
836 }
837
838 // When the input and output types do not match we should not be in a modulo operation. This version is provided for syntactical
839 // compatibility.
840 template <typename U>
doContainsvkt::shaderexecutor::ContainerTraits841 static bool doContains (const IVal& ival, const T& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor)
842 {
843 for (int ndx = 0; ndx < T::SIZE; ++ndx)
844 {
845 if (!contains(ival[ndx], value[ndx], is16Bit, modularDivisor))
846 return false;
847 }
848
849 return true;
850 }
851
doPrintIValvkt::shaderexecutor::ContainerTraits852 static void doPrintIVal (const FloatFormat& fmt, const IVal ival, ostream& os)
853 {
854 os << "(";
855
856 for (int ndx = 0; ndx < T::SIZE; ++ndx)
857 {
858 if (ndx > 0)
859 os << ", ";
860
861 printIVal<Element>(fmt, ival[ndx], os);
862 }
863
864 os << ")";
865 }
866
doPrintValue16vkt::shaderexecutor::ContainerTraits867 static void doPrintValue16 (const FloatFormat& fmt, const T& value, ostream& os)
868 {
869 os << dataTypeNameOf<T>() << "(";
870
871 for (int ndx = 0; ndx < T::SIZE; ++ndx)
872 {
873 if (ndx > 0)
874 os << ", ";
875
876 printValue16<Element>(fmt, value[ndx], os);
877 }
878
879 os << ")";
880 }
881
doPrintValue32vkt::shaderexecutor::ContainerTraits882 static void doPrintValue32 (const FloatFormat& fmt, const T& value, ostream& os)
883 {
884 os << dataTypeNameOf<T>() << "(";
885
886 for (int ndx = 0; ndx < T::SIZE; ++ndx)
887 {
888 if (ndx > 0)
889 os << ", ";
890
891 printValue32<Element>(fmt, value[ndx], os);
892 }
893
894 os << ")";
895 }
896
doPrintValue64vkt::shaderexecutor::ContainerTraits897 static void doPrintValue64 (const FloatFormat& fmt, const T& value, ostream& os)
898 {
899 os << dataTypeNameOf<T>() << "(";
900
901 for (int ndx = 0; ndx < T::SIZE; ++ndx)
902 {
903 if (ndx > 0)
904 os << ", ";
905
906 printValue64<Element>(fmt, value[ndx], os);
907 }
908
909 os << ")";
910 }
911
doConvertvkt::shaderexecutor::ContainerTraits912 static IVal doConvert (const FloatFormat& fmt, const IVal& value)
913 {
914 IVal ret;
915
916 for (int ndx = 0; ndx < T::SIZE; ++ndx)
917 ret[ndx] = convert<Element>(fmt, value[ndx]);
918
919 return ret;
920 }
921
doRoundvkt::shaderexecutor::ContainerTraits922 static IVal doRound (const FloatFormat& fmt, T value)
923 {
924 IVal ret;
925
926 for (int ndx = 0; ndx < T::SIZE; ++ndx)
927 ret[ndx] = round(fmt, value[ndx]);
928
929 return ret;
930 }
931 };
932
933 template <typename T, int Size>
934 struct Traits<Vector<T, Size> > :
935 ContainerTraits<Vector<T, Size>, Vector<typename Traits<T>::IVal, Size> >
936 {
937 };
938
939 template <typename T, int Rows, int Cols>
940 struct Traits<Matrix<T, Rows, Cols> > :
941 ContainerTraits<Matrix<T, Rows, Cols>, Matrix<typename Traits<T>::IVal, Rows, Cols> >
942 {
943 };
944
945 //! Void traits. These are just dummies, but technically valid: a Void is a
946 //! unit type with a single possible value.
947 template<>
948 struct Traits<Void>
949 {
950 typedef Void IVal;
951
doMakeIValvkt::shaderexecutor::Traits952 static Void doMakeIVal (const Void& value) { return value; }
doUnionvkt::shaderexecutor::Traits953 static Void doUnion (const Void&, const Void&) { return Void(); }
doContainsvkt::shaderexecutor::Traits954 static bool doContains (const Void&, Void) { return true; }
955 template <typename U>
doContainsvkt::shaderexecutor::Traits956 static bool doContains (const Void&, const Void& value, bool is16Bit, const tcu::Maybe<U>& modularDivisor) { DE_UNREF(value); DE_UNREF(is16Bit); DE_UNREF(modularDivisor); return true; }
doRoundvkt::shaderexecutor::Traits957 static Void doRound (const FloatFormat&, const Void& value) { return value; }
doConvertvkt::shaderexecutor::Traits958 static Void doConvert (const FloatFormat&, const Void& value) { return value; }
959
doPrintValue16vkt::shaderexecutor::Traits960 static void doPrintValue16 (const FloatFormat&, const Void&, ostream& os)
961 {
962 os << "()";
963 }
964
doPrintValue32vkt::shaderexecutor::Traits965 static void doPrintValue32 (const FloatFormat&, const Void&, ostream& os)
966 {
967 os << "()";
968 }
969
doPrintValue64vkt::shaderexecutor::Traits970 static void doPrintValue64 (const FloatFormat&, const Void&, ostream& os)
971 {
972 os << "()";
973 }
974
doPrintIValvkt::shaderexecutor::Traits975 static void doPrintIVal (const FloatFormat&, const Void&, ostream& os)
976 {
977 os << "()";
978 }
979 };
980
981 //! This is needed for container-generic operations.
982 //! We want a scalar type T to be its own "one-element vector".
983 template <typename T, int Size> struct ContainerOf { typedef Vector<T, Size> Container; };
984
985 template <typename T> struct ContainerOf<T, 1> { typedef T Container; };
986 template <int Size> struct ContainerOf<Void, Size> { typedef Void Container; };
987
//! Element type of container type T. This is a kludge that only exists to make
//! the ExprP::operator[] syntactic sugar work. The scalar types listed below
//! have no elements, so their Element is void.
template <typename T>
struct ElementOf
{
	typedef typename T::Element Element;
};

template <>
struct ElementOf<float>
{
	typedef void Element;
};

template <>
struct ElementOf<double>
{
	typedef void Element;
};

template <>
struct ElementOf<bool>
{
	typedef void Element;
};

template <>
struct ElementOf<int>
{
	typedef void Element;
};
994
995 template <typename T>
comparisonMessageInterval(const typename Traits<T>::IVal & val)996 string comparisonMessageInterval(const typename Traits<T>::IVal& val)
997 {
998 DE_UNREF(val);
999 return "";
1000 }
1001
1002 template <>
comparisonMessageInterval(const Traits<int>::IVal & val)1003 string comparisonMessageInterval<int>(const Traits<int>::IVal& val)
1004 {
1005 return comparisonMessage(static_cast<int>(val.lo()));
1006 }
1007
1008 template <>
comparisonMessageInterval(const Traits<float>::IVal & val)1009 string comparisonMessageInterval<float>(const Traits<float>::IVal& val)
1010 {
1011 return comparisonMessage(static_cast<int>(val.lo()));
1012 }
1013
1014 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,2> & val)1015 string comparisonMessageInterval<tcu::Vector<int, 2> >(const tcu::Vector<tcu::Interval, 2> & val)
1016 {
1017 tcu::IVec2 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()));
1018 return comparisonMessage(result);
1019 }
1020
1021 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,3> & val)1022 string comparisonMessageInterval<tcu::Vector<int, 3> >(const tcu::Vector<tcu::Interval, 3> & val)
1023 {
1024 tcu::IVec3 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()));
1025 return comparisonMessage(result);
1026 }
1027
1028 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,4> & val)1029 string comparisonMessageInterval<tcu::Vector<int, 4> >(const tcu::Vector<tcu::Interval, 4> & val)
1030 {
1031 tcu::IVec4 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()), static_cast<int>(val[3].lo()));
1032 return comparisonMessage(result);
1033 }
1034
1035 /*--------------------------------------------------------------------*//*!
1036 *
1037 * \name Abstract syntax for expressions and statements.
1038 *
1039 * We represent GLSL programs as syntax objects: an Expr<T> represents an
1040 * expression whose GLSL type corresponds to the C++ type T, and a Statement
1041 * represents a statement.
1042 *
1043 * To ease memory management, we use shared pointers to refer to expressions
1044 * and statements. ExprP<T> is a shared pointer to an Expr<T>, and StatementP
1045 * is a shared pointer to a Statement.
1046 *
1047 * \{
1048 *
1049 *//*--------------------------------------------------------------------*/
1050
1051 class ExprBase;
1052 class ExpandContext;
1053 class Statement;
1054 class StatementP;
1055 class FuncBase;
1056 template <typename T> class ExprP;
1057 template <typename T> class Variable;
1058 template <typename T> class VariableP;
1059 template <typename T> class DefaultSampling;
1060
1061 typedef set<const FuncBase*> FuncSet;
1062
1063 template <typename T>
1064 VariableP<T> variable (const string& name);
1065 StatementP compoundStatement (const vector<StatementP>& statements);
1066
1067 /*--------------------------------------------------------------------*//*!
1068 * \brief A variable environment.
1069 *
1070 * An Environment object maintains the mapping between variables of the
1071 * abstract syntax tree and their values.
1072 *
1073 * \todo [2014-03-28 lauri] At least run-time type safety.
1074 *
1075 *//*--------------------------------------------------------------------*/
1076 class Environment
1077 {
1078 public:
1079 template<typename T>
bind(const Variable<T> & variable,const typename Traits<T>::IVal & value)1080 void bind (const Variable<T>& variable,
1081 const typename Traits<T>::IVal& value)
1082 {
1083 deUint8* const data = new deUint8[sizeof(value)];
1084
1085 deMemcpy(data, &value, sizeof(value));
1086 de::insert(m_map, variable.getName(), SharedPtr<deUint8>(data, de::ArrayDeleter<deUint8>()));
1087 }
1088
1089 template<typename T>
lookup(const Variable<T> & variable) const1090 typename Traits<T>::IVal& lookup (const Variable<T>& variable) const
1091 {
1092 deUint8* const data = de::lookup(m_map, variable.getName()).get();
1093
1094 return *reinterpret_cast<typename Traits<T>::IVal*>(data);
1095 }
1096
1097 private:
1098 map<string, SharedPtr<deUint8> > m_map;
1099 };
1100
1101 /*--------------------------------------------------------------------*//*!
1102 * \brief Evaluation context.
1103 *
1104 * The evaluation context contains everything that separates one execution of
1105 * an expression from the next. Currently this means the desired floating
1106 * point precision and the current variable environment.
1107 *
1108 *//*--------------------------------------------------------------------*/
1109 struct EvalContext
1110 {
EvalContextvkt::shaderexecutor::EvalContext1111 EvalContext (const FloatFormat& format_,
1112 Precision floatPrecision_,
1113 Environment& env_,
1114 int callDepth_)
1115 : format (format_)
1116 , floatPrecision (floatPrecision_)
1117 , env (env_)
1118 , callDepth (callDepth_) {}
1119
1120 FloatFormat format;
1121 Precision floatPrecision;
1122 Environment& env;
1123 int callDepth;
1124 };
1125
/*--------------------------------------------------------------------*//*!
 * \brief Simple incremental counter.
 *
 * Each call returns the current count and then advances it by one. Sharing a
 * single Counter between ExpandContexts ensures they never produce
 * overlapping temporary names.
 *
 *//*--------------------------------------------------------------------*/
class Counter
{
public:
	Counter (int count = 0)
		: m_count (count)
	{
	}

	int operator() (void)
	{
		const int current = m_count;

		m_count += 1;

		return current;
	}

private:
	int m_count;
};
1142
1143 class ExpandContext
1144 {
1145 public:
ExpandContext(Counter & symCounter)1146 ExpandContext (Counter& symCounter) : m_symCounter(symCounter) {}
ExpandContext(const ExpandContext & parent)1147 ExpandContext (const ExpandContext& parent)
1148 : m_symCounter(parent.m_symCounter) {}
1149
1150 template<typename T>
genSym(const string & baseName)1151 VariableP<T> genSym (const string& baseName)
1152 {
1153 return variable<T>(baseName + de::toString(m_symCounter()));
1154 }
1155
addStatement(const StatementP & stmt)1156 void addStatement (const StatementP& stmt)
1157 {
1158 m_statements.push_back(stmt);
1159 }
1160
getStatements(void) const1161 vector<StatementP> getStatements (void) const
1162 {
1163 return m_statements;
1164 }
1165 private:
1166 Counter& m_symCounter;
1167 vector<StatementP> m_statements;
1168 };
1169
1170 /*--------------------------------------------------------------------*//*!
1171 * \brief A statement or declaration.
1172 *
1173 * Statements have no values. Instead, they are executed for their side
1174 * effects only: the execute() method should modify at least one variable in
1175 * the environment.
1176 *
1177 * As a bit of a kludge, a Statement object can also represent a declaration:
1178 * when it is evaluated, it can add a variable binding to the environment
1179 * instead of modifying a current one.
1180 *
1181 *//*--------------------------------------------------------------------*/
1182 class Statement
1183 {
1184 public:
~Statement(void)1185 virtual ~Statement (void) { }
1186 //! Execute the statement, modifying the environment of `ctx`
execute(EvalContext & ctx) const1187 void execute (EvalContext& ctx) const { this->doExecute(ctx); }
print(ostream & os) const1188 void print (ostream& os) const { this->doPrint(os); }
1189 //! Add the functions used in this statement to `dst`.
getUsedFuncs(FuncSet & dst) const1190 void getUsedFuncs (FuncSet& dst) const { this->doGetUsedFuncs(dst); }
failed(EvalContext & ctx) const1191 void failed (EvalContext& ctx) const { this->doFail(ctx); }
1192
1193 protected:
1194 virtual void doPrint (ostream& os) const = 0;
1195 virtual void doExecute (EvalContext& ctx) const = 0;
1196 virtual void doGetUsedFuncs (FuncSet& dst) const = 0;
doFail(EvalContext & ctx) const1197 virtual void doFail (EvalContext& ctx) const { DE_UNREF(ctx); }
1198 };
1199
operator <<(ostream & os,const Statement & stmt)1200 ostream& operator<<(ostream& os, const Statement& stmt)
1201 {
1202 stmt.print(os);
1203 return os;
1204 }
1205
1206 /*--------------------------------------------------------------------*//*!
1207 * \brief Smart pointer for statements (and declarations)
1208 *
1209 *//*--------------------------------------------------------------------*/
1210 class StatementP : public SharedPtr<const Statement>
1211 {
1212 public:
1213 typedef SharedPtr<const Statement> Super;
1214
StatementP(void)1215 StatementP (void) {}
StatementP(const Statement * ptr)1216 explicit StatementP (const Statement* ptr) : Super(ptr) {}
StatementP(const Super & ptr)1217 StatementP (const Super& ptr) : Super(ptr) {}
1218 };
1219
1220 /*--------------------------------------------------------------------*//*!
1221 * \brief
1222 *
1223 * A statement that modifies a variable or a declaration that binds a variable.
1224 *
1225 *//*--------------------------------------------------------------------*/
1226 template <typename T>
1227 class VariableStatement : public Statement
1228 {
1229 public:
VariableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1230 VariableStatement (const VariableP<T>& variable, const ExprP<T>& value,
1231 bool isDeclaration)
1232 : m_variable (variable)
1233 , m_value (value)
1234 , m_isDeclaration (isDeclaration) {}
1235
1236 protected:
doPrint(ostream & os) const1237 void doPrint (ostream& os) const
1238 {
1239 if (m_isDeclaration)
1240 os << glu::declare(getVarTypeOf<T>(), m_variable->getName());
1241 else
1242 os << m_variable->getName();
1243
1244 os << " = ";
1245 os<< *m_value << ";\n";
1246 }
1247
doExecute(EvalContext & ctx) const1248 void doExecute (EvalContext& ctx) const
1249 {
1250 if (m_isDeclaration)
1251 ctx.env.bind(*m_variable, m_value->evaluate(ctx));
1252 else
1253 ctx.env.lookup(*m_variable) = m_value->evaluate(ctx);
1254 }
1255
doGetUsedFuncs(FuncSet & dst) const1256 void doGetUsedFuncs (FuncSet& dst) const
1257 {
1258 m_value->getUsedFuncs(dst);
1259 }
1260
doFail(EvalContext & ctx) const1261 virtual void doFail (EvalContext& ctx) const
1262 {
1263 if (m_isDeclaration)
1264 ctx.env.bind(*m_variable, m_value->fails(ctx));
1265 else
1266 ctx.env.lookup(*m_variable) = m_value->fails(ctx);
1267 }
1268
1269 VariableP<T> m_variable;
1270 ExprP<T> m_value;
1271 bool m_isDeclaration;
1272 };
1273
1274 template <typename T>
variableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1275 StatementP variableStatement (const VariableP<T>& variable,
1276 const ExprP<T>& value,
1277 bool isDeclaration)
1278 {
1279 return StatementP(new VariableStatement<T>(variable, value, isDeclaration));
1280 }
1281
1282 template <typename T>
variableDeclaration(const VariableP<T> & variable,const ExprP<T> & definiens)1283 StatementP variableDeclaration (const VariableP<T>& variable, const ExprP<T>& definiens)
1284 {
1285 return variableStatement(variable, definiens, true);
1286 }
1287
1288 template <typename T>
variableAssignment(const VariableP<T> & variable,const ExprP<T> & value)1289 StatementP variableAssignment (const VariableP<T>& variable, const ExprP<T>& value)
1290 {
1291 return variableStatement(variable, value, false);
1292 }
1293
1294 /*--------------------------------------------------------------------*//*!
1295 * \brief A compound statement, i.e. a block.
1296 *
1297 * A compound statement is executed by executing its constituent statements in
1298 * sequence.
1299 *
1300 *//*--------------------------------------------------------------------*/
1301 class CompoundStatement : public Statement
1302 {
1303 public:
CompoundStatement(const vector<StatementP> & statements)1304 CompoundStatement (const vector<StatementP>& statements)
1305 : m_statements (statements) {}
1306
1307 protected:
doPrint(ostream & os) const1308 void doPrint (ostream& os) const
1309 {
1310 os << "{\n";
1311
1312 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1313 os << *m_statements[ndx];
1314
1315 os << "}\n";
1316 }
1317
doExecute(EvalContext & ctx) const1318 void doExecute (EvalContext& ctx) const
1319 {
1320 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1321 m_statements[ndx]->execute(ctx);
1322 }
1323
doGetUsedFuncs(FuncSet & dst) const1324 void doGetUsedFuncs (FuncSet& dst) const
1325 {
1326 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1327 m_statements[ndx]->getUsedFuncs(dst);
1328 }
1329
1330 vector<StatementP> m_statements;
1331 };
1332
compoundStatement(const vector<StatementP> & statements)1333 StatementP compoundStatement(const vector<StatementP>& statements)
1334 {
1335 return StatementP(new CompoundStatement(statements));
1336 }
1337
1338 //! Common base class for all expressions regardless of their type.
1339 class ExprBase
1340 {
1341 public:
~ExprBase(void)1342 virtual ~ExprBase (void) {}
printExpr(ostream & os) const1343 void printExpr (ostream& os) const { this->doPrintExpr(os); }
1344
1345 //! Output the functions that this expression refers to
getUsedFuncs(FuncSet & dst) const1346 void getUsedFuncs (FuncSet& dst) const
1347 {
1348 this->doGetUsedFuncs(dst);
1349 }
1350
1351 protected:
doPrintExpr(ostream &) const1352 virtual void doPrintExpr (ostream&) const {}
doGetUsedFuncs(FuncSet &) const1353 virtual void doGetUsedFuncs (FuncSet&) const {}
1354 };
1355
1356 //! Type-specific operations for an expression representing type T.
1357 template <typename T>
1358 class Expr : public ExprBase
1359 {
1360 public:
1361 typedef T Val;
1362 typedef typename Traits<T>::IVal IVal;
1363
1364 IVal evaluate (const EvalContext& ctx) const;
fails(const EvalContext & ctx) const1365 IVal fails (const EvalContext& ctx) const { return this->doFails(ctx); }
1366
1367 protected:
1368 virtual IVal doEvaluate (const EvalContext& ctx) const = 0;
doFails(const EvalContext & ctx) const1369 virtual IVal doFails (const EvalContext& ctx) const {return doEvaluate(ctx);}
1370 };
1371
1372 //! Evaluate an expression with the given context, optionally tracing the calls to stderr.
1373 template <typename T>
evaluate(const EvalContext & ctx) const1374 typename Traits<T>::IVal Expr<T>::evaluate (const EvalContext& ctx) const
1375 {
1376 #ifdef GLS_ENABLE_TRACE
1377 static const FloatFormat highpFmt (-126, 127, 23, true,
1378 tcu::MAYBE,
1379 tcu::YES,
1380 tcu::MAYBE);
1381 EvalContext newCtx (ctx.format, ctx.floatPrecision,
1382 ctx.env, ctx.callDepth + 1);
1383 const IVal ret = this->doEvaluate(newCtx);
1384
1385 if (isTypeValid<T>())
1386 {
1387 std::cerr << string(ctx.callDepth, ' ');
1388 this->printExpr(std::cerr);
1389 std::cerr << " -> " << intervalToString<T>(highpFmt, ret) << std::endl;
1390 }
1391 return ret;
1392 #else
1393 return this->doEvaluate(ctx);
1394 #endif
1395 }
1396
1397 template <typename T>
1398 class ExprPBase : public SharedPtr<const Expr<T> >
1399 {
1400 public:
1401 };
1402
operator <<(ostream & os,const ExprBase & expr)1403 ostream& operator<< (ostream& os, const ExprBase& expr)
1404 {
1405 expr.printExpr(os);
1406 return os;
1407 }
1408
1409 /*--------------------------------------------------------------------*//*!
1410 * \brief Shared pointer to an expression of a container type.
1411 *
1412 * Container types (i.e. vectors and matrices) support the subscription
1413 * operator. This class provides a bit of syntactic sugar to allow us to use
1414 * the C++ subscription operator to create a subscription expression.
1415 *//*--------------------------------------------------------------------*/
1416 template <typename T>
1417 class ContainerExprPBase : public ExprPBase<T>
1418 {
1419 public:
1420 ExprP<typename T::Element> operator[] (int i) const;
1421 };
1422
1423 template <typename T>
1424 class ExprP : public ExprPBase<T> {};
1425
1426 // We treat Voids as containers since the unused parameters in generalized
1427 // vector functions are represented as Voids.
1428 template <>
1429 class ExprP<Void> : public ContainerExprPBase<Void> {};
1430
1431 template <typename T, int Size>
1432 class ExprP<Vector<T, Size> > : public ContainerExprPBase<Vector<T, Size> > {};
1433
1434 template <typename T, int Rows, int Cols>
1435 class ExprP<Matrix<T, Rows, Cols> > : public ContainerExprPBase<Matrix<T, Rows, Cols> > {};
1436
exprP(void)1437 template <typename T> ExprP<T> exprP (void)
1438 {
1439 return ExprP<T>();
1440 }
1441
1442 template <typename T>
exprP(const SharedPtr<const Expr<T>> & ptr)1443 ExprP<T> exprP (const SharedPtr<const Expr<T> >& ptr)
1444 {
1445 ExprP<T> ret;
1446 static_cast<SharedPtr<const Expr<T> >&>(ret) = ptr;
1447 return ret;
1448 }
1449
1450 template <typename T>
exprP(const Expr<T> * ptr)1451 ExprP<T> exprP (const Expr<T>* ptr)
1452 {
1453 return exprP(SharedPtr<const Expr<T> >(ptr));
1454 }
1455
1456 /*--------------------------------------------------------------------*//*!
1457 * \brief A shared pointer to a variable expression.
1458 *
1459 * This is just a narrowing of ExprP for the operations that require a variable
1460 * instead of an arbitrary expression.
1461 *
1462 *//*--------------------------------------------------------------------*/
1463 template <typename T>
1464 class VariableP : public SharedPtr<const Variable<T> >
1465 {
1466 public:
1467 typedef SharedPtr<const Variable<T> > Super;
VariableP(const Variable<T> * ptr)1468 explicit VariableP (const Variable<T>* ptr) : Super(ptr) {}
VariableP(void)1469 VariableP (void) {}
VariableP(const Super & ptr)1470 VariableP (const Super& ptr) : Super(ptr) {}
1471
operator ExprP<T>(void) const1472 operator ExprP<T> (void) const { return exprP(SharedPtr<const Expr<T> >(*this)); }
1473 };
1474
1475 /*--------------------------------------------------------------------*//*!
1476 * \name Syntactic sugar operators for expressions.
1477 *
1478 * @{
1479 *
1480 * These operators allow the use of C++ syntax to construct GLSL expressions
1481 * containing operators: e.g. "a+b" creates an addition expression with
1482 * operands a and b, and so on.
1483 *
1484 *//*--------------------------------------------------------------------*/
1485 ExprP<float> operator+ (const ExprP<float>& arg0,
1486 const ExprP<float>& arg1);
1487 ExprP<deFloat16> operator+ (const ExprP<deFloat16>& arg0,
1488 const ExprP<deFloat16>& arg1);
1489 ExprP<double> operator+ (const ExprP<double>& arg0,
1490 const ExprP<double>& arg1);
1491 template <typename T>
1492 ExprP<T> operator- (const ExprP<T>& arg0);
1493 template <typename T>
1494 ExprP<T> operator- (const ExprP<T>& arg0,
1495 const ExprP<T>& arg1);
1496 template<int Left, int Mid, int Right, typename T>
1497 ExprP<Matrix<T, Left, Right> > operator* (const ExprP<Matrix<T, Left, Mid> >& left,
1498 const ExprP<Matrix<T, Mid, Right> >& right);
1499 ExprP<float> operator* (const ExprP<float>& arg0,
1500 const ExprP<float>& arg1);
1501 ExprP<deFloat16> operator* (const ExprP<deFloat16>& arg0,
1502 const ExprP<deFloat16>& arg1);
1503 ExprP<double> operator* (const ExprP<double>& arg0,
1504 const ExprP<double>& arg1);
1505 template <typename T>
1506 ExprP<T> operator/ (const ExprP<T>& arg0,
1507 const ExprP<T>& arg1);
1508 template<typename T, int Size>
1509 ExprP<Vector<T, Size> > operator- (const ExprP<Vector<T, Size> >& arg0);
1510 template<typename T, int Size>
1511 ExprP<Vector<T, Size> > operator- (const ExprP<Vector<T, Size> >& arg0,
1512 const ExprP<Vector<T, Size> >& arg1);
1513 template<int Size, typename T>
1514 ExprP<Vector<T, Size> > operator* (const ExprP<Vector<T, Size> >& arg0,
1515 const ExprP<T>& arg1);
1516 template<typename T, int Size>
1517 ExprP<Vector<T, Size> > operator* (const ExprP<Vector<T, Size> >& arg0,
1518 const ExprP<Vector<T, Size> >& arg1);
1519 template<int Rows, int Cols, typename T>
1520 ExprP<Vector<T, Rows> > operator* (const ExprP<Vector<T, Cols> >& left,
1521 const ExprP<Matrix<T, Rows, Cols> >& right);
1522 template<int Rows, int Cols, typename T>
1523 ExprP<Vector<T, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
1524 const ExprP<Vector<T, Rows> >& right);
1525 template<int Rows, int Cols, typename T>
1526 ExprP<Matrix<T, Rows, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
1527 const ExprP<T>& right);
1528 template<int Rows, int Cols>
1529 ExprP<Matrix<float, Rows, Cols> > operator+ (const ExprP<Matrix<float, Rows, Cols> >& left,
1530 const ExprP<Matrix<float, Rows, Cols> >& right);
1531 template<int Rows, int Cols>
1532 ExprP<Matrix<deFloat16, Rows, Cols> > operator+ (const ExprP<Matrix<deFloat16, Rows, Cols> >& left,
1533 const ExprP<Matrix<deFloat16, Rows, Cols> >& right);
1534 template<int Rows, int Cols>
1535 ExprP<Matrix<double, Rows, Cols> > operator+ (const ExprP<Matrix<double, Rows, Cols> >& left,
1536 const ExprP<Matrix<double, Rows, Cols> >& right);
1537 template<typename T, int Rows, int Cols>
1538 ExprP<Matrix<T, Rows, Cols> > operator- (const ExprP<Matrix<T, Rows, Cols> >& mat);
1539
1540 //! @}
1541
1542 /*--------------------------------------------------------------------*//*!
1543 * \brief Variable expression.
1544 *
1545 * A variable is evaluated by looking up its range of possible values from an
1546 * environment.
1547 *//*--------------------------------------------------------------------*/
1548 template <typename T>
1549 class Variable : public Expr<T>
1550 {
1551 public:
1552 typedef typename Expr<T>::IVal IVal;
1553
Variable(const string & name)1554 Variable (const string& name) : m_name (name) {}
getName(void) const1555 string getName (void) const { return m_name; }
1556
1557 protected:
doPrintExpr(ostream & os) const1558 void doPrintExpr (ostream& os) const { os << m_name; }
doEvaluate(const EvalContext & ctx) const1559 IVal doEvaluate (const EvalContext& ctx) const
1560 {
1561 return ctx.env.lookup<T>(*this);
1562 }
1563
1564 private:
1565 string m_name;
1566 };
1567
1568 template <typename T>
variable(const string & name)1569 VariableP<T> variable (const string& name)
1570 {
1571 return VariableP<T>(new Variable<T>(name));
1572 }
1573
1574 template <typename T>
bindExpression(const string & name,ExpandContext & ctx,const ExprP<T> & expr)1575 VariableP<T> bindExpression (const string& name, ExpandContext& ctx, const ExprP<T>& expr)
1576 {
1577 VariableP<T> var = ctx.genSym<T>(name);
1578 ctx.addStatement(variableDeclaration(var, expr));
1579 return var;
1580 }
1581
1582 /*--------------------------------------------------------------------*//*!
1583 * \brief Constant expression.
1584 *
1585 * A constant is evaluated by rounding it to a set of possible values allowed
1586 * by the current floating point precision.
1587 * TODO: For whatever reason this doesn't happen, the constant is converted to
1588 * type T and the interval contains only (T)value. See FloatConstant, below.
1589 *//*--------------------------------------------------------------------*/
1590 template <typename T>
1591 class Constant : public Expr<T>
1592 {
1593 public:
1594 typedef typename Expr<T>::IVal IVal;
1595
Constant(const T & value)1596 Constant (const T& value) : m_value(value) {}
1597
1598 protected:
doPrintExpr(ostream & os) const1599 void doPrintExpr (ostream& os) const { os << m_value; }
doEvaluate(const EvalContext &) const1600 IVal doEvaluate (const EvalContext&) const { return makeIVal(m_value); }
1601
1602 private:
1603 T m_value;
1604 };
1605
1606 template <typename T>
constant(const T & value)1607 ExprP<T> constant (const T& value)
1608 {
1609 return exprP(new Constant<T>(value));
1610 }
1611
1612 template <typename T>
1613 class FloatConstant : public Expr<T>
1614 {
1615 public:
1616 typedef typename Expr<T>::IVal IVal;
1617
FloatConstant(double value)1618 FloatConstant (double value) : m_value(value) {}
1619
1620 protected:
doPrintExpr(ostream & os) const1621 void doPrintExpr (ostream& os) const { os << m_value; }
1622 // TODO: This should probably roundOut to T, not ctx.format, but the templates don't work like that.
doEvaluate(const EvalContext & ctx) const1623 IVal doEvaluate (const EvalContext &ctx) const { return ctx.format.roundOut(makeIVal(m_value), true); }
1624
1625 private:
1626 double m_value;
1627 };
1628
f16Constant(double value)1629 ExprP<deFloat16> f16Constant (double value)
1630 {
1631 return exprP(new FloatConstant<deFloat16>(value));
1632 }
f32Constant(double value)1633 ExprP<float> f32Constant (double value)
1634 {
1635 return exprP(new FloatConstant<float>(value));
1636 }
1637
1638 //! Return a reference to a singleton void constant.
voidP(void)1639 const ExprP<Void>& voidP (void)
1640 {
1641 static const ExprP<Void> singleton = constant(Void());
1642
1643 return singleton;
1644 }
1645
1646 /*--------------------------------------------------------------------*//*!
1647 * \brief Four-element tuple.
1648 *
1649 * This is used for various things where we need one thing for each possible
1650 * function parameter. Currently the maximum supported number of parameters is
1651 * four.
1652 *//*--------------------------------------------------------------------*/
1653 template <typename T0 = Void, typename T1 = Void, typename T2 = Void, typename T3 = Void>
1654 struct Tuple4
1655 {
Tuple4vkt::shaderexecutor::Tuple41656 explicit Tuple4 (const T0 e0 = T0(),
1657 const T1 e1 = T1(),
1658 const T2 e2 = T2(),
1659 const T3 e3 = T3())
1660 : a (e0)
1661 , b (e1)
1662 , c (e2)
1663 , d (e3)
1664 {
1665 }
1666
1667 T0 a;
1668 T1 b;
1669 T2 c;
1670 T3 d;
1671 };
1672
1673 /*--------------------------------------------------------------------*//*!
1674 * \brief Function signature.
1675 *
1676 * This is a purely compile-time structure used to bundle all types in a
1677 * function signature together. This makes passing the signature around in
1678 * templates easier, since we only need to take and pass a single Sig instead
1679 * of a bunch of parameter types and a return type.
1680 *
1681 *//*--------------------------------------------------------------------*/
1682 template <typename R,
1683 typename P0 = Void, typename P1 = Void,
1684 typename P2 = Void, typename P3 = Void>
1685 struct Signature
1686 {
1687 typedef R Ret;
1688 typedef P0 Arg0;
1689 typedef P1 Arg1;
1690 typedef P2 Arg2;
1691 typedef P3 Arg3;
1692 typedef typename Traits<Ret>::IVal IRet;
1693 typedef typename Traits<Arg0>::IVal IArg0;
1694 typedef typename Traits<Arg1>::IVal IArg1;
1695 typedef typename Traits<Arg2>::IVal IArg2;
1696 typedef typename Traits<Arg3>::IVal IArg3;
1697
1698 typedef Tuple4< const Arg0&, const Arg1&, const Arg2&, const Arg3&> Args;
1699 typedef Tuple4< const IArg0&, const IArg1&, const IArg2&, const IArg3&> IArgs;
1700 typedef Tuple4< ExprP<Arg0>, ExprP<Arg1>, ExprP<Arg2>, ExprP<Arg3> > ArgExprs;
1701 };
1702
1703 typedef vector<const ExprBase*> BaseArgExprs;
1704
1705 /*--------------------------------------------------------------------*//*!
1706 * \brief Type-independent operations for function objects.
1707 *
1708 *//*--------------------------------------------------------------------*/
1709 class FuncBase
1710 {
1711 public:
~FuncBase(void)1712 virtual ~FuncBase (void) {}
1713 virtual string getName (void) const = 0;
1714 //! Name of extension that this function requires, or empty.
getRequiredExtension(void) const1715 virtual string getRequiredExtension (void) const { return ""; }
getInputRange(const bool is16bit) const1716 virtual Interval getInputRange (const bool is16bit) const {DE_UNREF(is16bit); return Interval(true, -TCU_INFINITY, TCU_INFINITY); }
1717 virtual void print (ostream&,
1718 const BaseArgExprs&) const = 0;
1719 //! Index of output parameter, or -1 if none of the parameters is output.
getOutParamIndex(void) const1720 virtual int getOutParamIndex (void) const { return -1; }
1721
getSpirvCase(void) const1722 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_NONE; }
1723
printDefinition(ostream & os) const1724 void printDefinition (ostream& os) const
1725 {
1726 doPrintDefinition(os);
1727 }
1728
getUsedFuncs(FuncSet & dst) const1729 void getUsedFuncs (FuncSet& dst) const
1730 {
1731 this->doGetUsedFuncs(dst);
1732 }
1733
1734 protected:
1735 virtual void doPrintDefinition (ostream& os) const = 0;
1736 virtual void doGetUsedFuncs (FuncSet& dst) const = 0;
1737 };
1738
1739 typedef Tuple4<string, string, string, string> ParamNames;
1740
1741 /*--------------------------------------------------------------------*//*!
1742 * \brief Function objects.
1743 *
1744 * Each Func object represents a GLSL function. It can be applied to interval
1745 * arguments, and it returns the an interval that is a conservative
1746 * approximation of the image of the GLSL function over the argument
1747 * intervals. That is, it is given a set of possible arguments and it returns
1748 * the set of possible values.
1749 *
1750 *//*--------------------------------------------------------------------*/
1751 template <typename Sig_>
1752 class Func : public FuncBase
1753 {
1754 public:
1755 typedef Sig_ Sig;
1756 typedef typename Sig::Ret Ret;
1757 typedef typename Sig::Arg0 Arg0;
1758 typedef typename Sig::Arg1 Arg1;
1759 typedef typename Sig::Arg2 Arg2;
1760 typedef typename Sig::Arg3 Arg3;
1761 typedef typename Sig::IRet IRet;
1762 typedef typename Sig::IArg0 IArg0;
1763 typedef typename Sig::IArg1 IArg1;
1764 typedef typename Sig::IArg2 IArg2;
1765 typedef typename Sig::IArg3 IArg3;
1766 typedef typename Sig::Args Args;
1767 typedef typename Sig::IArgs IArgs;
1768 typedef typename Sig::ArgExprs ArgExprs;
1769
print(ostream & os,const BaseArgExprs & args) const1770 void print (ostream& os,
1771 const BaseArgExprs& args) const
1772 {
1773 this->doPrint(os, args);
1774 }
1775
apply(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1776 IRet apply (const EvalContext& ctx,
1777 const IArg0& arg0 = IArg0(),
1778 const IArg1& arg1 = IArg1(),
1779 const IArg2& arg2 = IArg2(),
1780 const IArg3& arg3 = IArg3()) const
1781 {
1782 return this->applyArgs(ctx, IArgs(arg0, arg1, arg2, arg3));
1783 }
1784
fail(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1785 IRet fail (const EvalContext& ctx,
1786 const IArg0& arg0 = IArg0(),
1787 const IArg1& arg1 = IArg1(),
1788 const IArg2& arg2 = IArg2(),
1789 const IArg3& arg3 = IArg3()) const
1790 {
1791 return this->doFail(ctx, IArgs(arg0, arg1, arg2, arg3));
1792 }
applyArgs(const EvalContext & ctx,const IArgs & args) const1793 IRet applyArgs (const EvalContext& ctx,
1794 const IArgs& args) const
1795 {
1796 return this->doApply(ctx, args);
1797 }
1798 ExprP<Ret> operator() (const ExprP<Arg0>& arg0 = voidP(),
1799 const ExprP<Arg1>& arg1 = voidP(),
1800 const ExprP<Arg2>& arg2 = voidP(),
1801 const ExprP<Arg3>& arg3 = voidP()) const;
1802
getParamNames(void) const1803 const ParamNames& getParamNames (void) const
1804 {
1805 return this->doGetParamNames();
1806 }
1807
1808 protected:
1809 virtual IRet doApply (const EvalContext&,
1810 const IArgs&) const = 0;
doFail(const EvalContext & ctx,const IArgs & args) const1811 virtual IRet doFail (const EvalContext& ctx,
1812 const IArgs& args) const
1813 {
1814 return this->doApply(ctx, args);
1815 }
doPrint(ostream & os,const BaseArgExprs & args) const1816 virtual void doPrint (ostream& os, const BaseArgExprs& args) const
1817 {
1818 os << getName() << "(";
1819
1820 if (isTypeValid<Arg0>())
1821 os << *args[0];
1822
1823 if (isTypeValid<Arg1>())
1824 os << ", " << *args[1];
1825
1826 if (isTypeValid<Arg2>())
1827 os << ", " << *args[2];
1828
1829 if (isTypeValid<Arg3>())
1830 os << ", " << *args[3];
1831
1832 os << ")";
1833 }
1834
doGetParamNames(void) const1835 virtual const ParamNames& doGetParamNames (void) const
1836 {
1837 static ParamNames names ("a", "b", "c", "d");
1838 return names;
1839 }
1840 };
1841
1842 template <typename Sig>
1843 class Apply : public Expr<typename Sig::Ret>
1844 {
1845 public:
1846 typedef typename Sig::Ret Ret;
1847 typedef typename Sig::Arg0 Arg0;
1848 typedef typename Sig::Arg1 Arg1;
1849 typedef typename Sig::Arg2 Arg2;
1850 typedef typename Sig::Arg3 Arg3;
1851 typedef typename Expr<Ret>::Val Val;
1852 typedef typename Expr<Ret>::IVal IVal;
1853 typedef Func<Sig> ApplyFunc;
1854 typedef typename ApplyFunc::ArgExprs ArgExprs;
1855
Apply(const ApplyFunc & func,const ExprP<Arg0> & arg0=voidP (),const ExprP<Arg1> & arg1=voidP (),const ExprP<Arg2> & arg2=voidP (),const ExprP<Arg3> & arg3=voidP ())1856 Apply (const ApplyFunc& func,
1857 const ExprP<Arg0>& arg0 = voidP(),
1858 const ExprP<Arg1>& arg1 = voidP(),
1859 const ExprP<Arg2>& arg2 = voidP(),
1860 const ExprP<Arg3>& arg3 = voidP())
1861 : m_func (func),
1862 m_args (arg0, arg1, arg2, arg3) {}
1863
Apply(const ApplyFunc & func,const ArgExprs & args)1864 Apply (const ApplyFunc& func,
1865 const ArgExprs& args)
1866 : m_func (func),
1867 m_args (args) {}
1868 protected:
doPrintExpr(ostream & os) const1869 void doPrintExpr (ostream& os) const
1870 {
1871 BaseArgExprs args;
1872 args.push_back(m_args.a.get());
1873 args.push_back(m_args.b.get());
1874 args.push_back(m_args.c.get());
1875 args.push_back(m_args.d.get());
1876 m_func.print(os, args);
1877 }
1878
doEvaluate(const EvalContext & ctx) const1879 IVal doEvaluate (const EvalContext& ctx) const
1880 {
1881 return m_func.apply(ctx,
1882 m_args.a->evaluate(ctx), m_args.b->evaluate(ctx),
1883 m_args.c->evaluate(ctx), m_args.d->evaluate(ctx));
1884 }
1885
doGetUsedFuncs(FuncSet & dst) const1886 void doGetUsedFuncs (FuncSet& dst) const
1887 {
1888 m_func.getUsedFuncs(dst);
1889 m_args.a->getUsedFuncs(dst);
1890 m_args.b->getUsedFuncs(dst);
1891 m_args.c->getUsedFuncs(dst);
1892 m_args.d->getUsedFuncs(dst);
1893 }
1894
1895 const ApplyFunc& m_func;
1896 ArgExprs m_args;
1897 };
1898
1899 template<typename T>
1900 class Alternatives : public Func<Signature<T, T, T> >
1901 {
1902 public:
1903 typedef typename Alternatives::Sig Sig;
1904
1905 protected:
1906 typedef typename Alternatives::IRet IRet;
1907 typedef typename Alternatives::IArgs IArgs;
1908
getName(void) const1909 virtual string getName (void) const { return "alternatives"; }
doPrintDefinition(std::ostream &) const1910 virtual void doPrintDefinition (std::ostream&) const {}
doGetUsedFuncs(FuncSet &) const1911 void doGetUsedFuncs (FuncSet&) const {}
1912
doApply(const EvalContext &,const IArgs & args) const1913 virtual IRet doApply (const EvalContext&, const IArgs& args) const
1914 {
1915 return unionIVal<T>(args.a, args.b);
1916 }
1917
doPrint(ostream & os,const BaseArgExprs & args) const1918 virtual void doPrint (ostream& os, const BaseArgExprs& args) const
1919 {
1920 os << "{" << *args[0] << " | " << *args[1] << "}";
1921 }
1922 };
1923
1924 template <typename Sig>
createApply(const Func<Sig> & func,const typename Func<Sig>::ArgExprs & args)1925 ExprP<typename Sig::Ret> createApply (const Func<Sig>& func,
1926 const typename Func<Sig>::ArgExprs& args)
1927 {
1928 return exprP(new Apply<Sig>(func, args));
1929 }
1930
1931 template <typename Sig>
createApply(const Func<Sig> & func,const ExprP<typename Sig::Arg0> & arg0=voidP (),const ExprP<typename Sig::Arg1> & arg1=voidP (),const ExprP<typename Sig::Arg2> & arg2=voidP (),const ExprP<typename Sig::Arg3> & arg3=voidP ())1932 ExprP<typename Sig::Ret> createApply (
1933 const Func<Sig>& func,
1934 const ExprP<typename Sig::Arg0>& arg0 = voidP(),
1935 const ExprP<typename Sig::Arg1>& arg1 = voidP(),
1936 const ExprP<typename Sig::Arg2>& arg2 = voidP(),
1937 const ExprP<typename Sig::Arg3>& arg3 = voidP())
1938 {
1939 return exprP(new Apply<Sig>(func, arg0, arg1, arg2, arg3));
1940 }
1941
1942 template <typename Sig>
operator ()(const ExprP<typename Sig::Arg0> & arg0,const ExprP<typename Sig::Arg1> & arg1,const ExprP<typename Sig::Arg2> & arg2,const ExprP<typename Sig::Arg3> & arg3) const1943 ExprP<typename Sig::Ret> Func<Sig>::operator() (const ExprP<typename Sig::Arg0>& arg0,
1944 const ExprP<typename Sig::Arg1>& arg1,
1945 const ExprP<typename Sig::Arg2>& arg2,
1946 const ExprP<typename Sig::Arg3>& arg3) const
1947 {
1948 return createApply(*this, arg0, arg1, arg2, arg3);
1949 }
1950
1951 template <typename F>
app(const ExprP<typename F::Arg0> & arg0=voidP (),const ExprP<typename F::Arg1> & arg1=voidP (),const ExprP<typename F::Arg2> & arg2=voidP (),const ExprP<typename F::Arg3> & arg3=voidP ())1952 ExprP<typename F::Ret> app (const ExprP<typename F::Arg0>& arg0 = voidP(),
1953 const ExprP<typename F::Arg1>& arg1 = voidP(),
1954 const ExprP<typename F::Arg2>& arg2 = voidP(),
1955 const ExprP<typename F::Arg3>& arg3 = voidP())
1956 {
1957 return createApply(instance<F>(), arg0, arg1, arg2, arg3);
1958 }
1959
1960 template <typename F>
call(const EvalContext & ctx,const typename F::IArg0 & arg0=Void (),const typename F::IArg1 & arg1=Void (),const typename F::IArg2 & arg2=Void (),const typename F::IArg3 & arg3=Void ())1961 typename F::IRet call (const EvalContext& ctx,
1962 const typename F::IArg0& arg0 = Void(),
1963 const typename F::IArg1& arg1 = Void(),
1964 const typename F::IArg2& arg2 = Void(),
1965 const typename F::IArg3& arg3 = Void())
1966 {
1967 return instance<F>().apply(ctx, arg0, arg1, arg2, arg3);
1968 }
1969
1970 template <typename T>
alternatives(const ExprP<T> & arg0,const ExprP<T> & arg1)1971 ExprP<T> alternatives (const ExprP<T>& arg0,
1972 const ExprP<T>& arg1)
1973 {
1974 return createApply<typename Alternatives<T>::Sig>(instance<Alternatives<T> >(), arg0, arg1);
1975 }
1976
1977 template <typename Sig>
1978 class ApplyVar : public Apply<Sig>
1979 {
1980 public:
1981 typedef typename Sig::Ret Ret;
1982 typedef typename Sig::Arg0 Arg0;
1983 typedef typename Sig::Arg1 Arg1;
1984 typedef typename Sig::Arg2 Arg2;
1985 typedef typename Sig::Arg3 Arg3;
1986 typedef typename Expr<Ret>::Val Val;
1987 typedef typename Expr<Ret>::IVal IVal;
1988 typedef Func<Sig> ApplyFunc;
1989 typedef typename ApplyFunc::ArgExprs ArgExprs;
1990
ApplyVar(const ApplyFunc & func,const VariableP<Arg0> & arg0,const VariableP<Arg1> & arg1,const VariableP<Arg2> & arg2,const VariableP<Arg3> & arg3)1991 ApplyVar (const ApplyFunc& func,
1992 const VariableP<Arg0>& arg0,
1993 const VariableP<Arg1>& arg1,
1994 const VariableP<Arg2>& arg2,
1995 const VariableP<Arg3>& arg3)
1996 : Apply<Sig> (func, arg0, arg1, arg2, arg3) {}
1997 protected:
doEvaluate(const EvalContext & ctx) const1998 IVal doEvaluate (const EvalContext& ctx) const
1999 {
2000 const Variable<Arg0>& var0 = static_cast<const Variable<Arg0>&>(*this->m_args.a);
2001 const Variable<Arg1>& var1 = static_cast<const Variable<Arg1>&>(*this->m_args.b);
2002 const Variable<Arg2>& var2 = static_cast<const Variable<Arg2>&>(*this->m_args.c);
2003 const Variable<Arg3>& var3 = static_cast<const Variable<Arg3>&>(*this->m_args.d);
2004 return this->m_func.apply(ctx,
2005 ctx.env.lookup(var0), ctx.env.lookup(var1),
2006 ctx.env.lookup(var2), ctx.env.lookup(var3));
2007 }
2008
doFails(const EvalContext & ctx) const2009 IVal doFails (const EvalContext& ctx) const
2010 {
2011 const Variable<Arg0>& var0 = static_cast<const Variable<Arg0>&>(*this->m_args.a);
2012 const Variable<Arg1>& var1 = static_cast<const Variable<Arg1>&>(*this->m_args.b);
2013 const Variable<Arg2>& var2 = static_cast<const Variable<Arg2>&>(*this->m_args.c);
2014 const Variable<Arg3>& var3 = static_cast<const Variable<Arg3>&>(*this->m_args.d);
2015 return this->m_func.fail(ctx,
2016 ctx.env.lookup(var0), ctx.env.lookup(var1),
2017 ctx.env.lookup(var2), ctx.env.lookup(var3));
2018 }
2019 };
2020
2021 template <typename Sig>
applyVar(const Func<Sig> & func,const VariableP<typename Sig::Arg0> & arg0,const VariableP<typename Sig::Arg1> & arg1,const VariableP<typename Sig::Arg2> & arg2,const VariableP<typename Sig::Arg3> & arg3)2022 ExprP<typename Sig::Ret> applyVar (const Func<Sig>& func,
2023 const VariableP<typename Sig::Arg0>& arg0,
2024 const VariableP<typename Sig::Arg1>& arg1,
2025 const VariableP<typename Sig::Arg2>& arg2,
2026 const VariableP<typename Sig::Arg3>& arg3)
2027 {
2028 return exprP(new ApplyVar<Sig>(func, arg0, arg1, arg2, arg3));
2029 }
2030
2031 template <typename Sig_>
2032 class DerivedFunc : public Func<Sig_>
2033 {
2034 public:
2035 typedef typename DerivedFunc::ArgExprs ArgExprs;
2036 typedef typename DerivedFunc::IRet IRet;
2037 typedef typename DerivedFunc::IArgs IArgs;
2038 typedef typename DerivedFunc::Ret Ret;
2039 typedef typename DerivedFunc::Arg0 Arg0;
2040 typedef typename DerivedFunc::Arg1 Arg1;
2041 typedef typename DerivedFunc::Arg2 Arg2;
2042 typedef typename DerivedFunc::Arg3 Arg3;
2043 typedef typename DerivedFunc::IArg0 IArg0;
2044 typedef typename DerivedFunc::IArg1 IArg1;
2045 typedef typename DerivedFunc::IArg2 IArg2;
2046 typedef typename DerivedFunc::IArg3 IArg3;
2047
2048 protected:
doPrintDefinition(ostream & os) const2049 void doPrintDefinition (ostream& os) const
2050 {
2051 const ParamNames& paramNames = this->getParamNames();
2052
2053 initialize();
2054
2055 os << dataTypeNameOf<Ret>() << " " << this->getName()
2056 << "(";
2057 if (isTypeValid<Arg0>())
2058 os << dataTypeNameOf<Arg0>() << " " << paramNames.a;
2059 if (isTypeValid<Arg1>())
2060 os << ", " << dataTypeNameOf<Arg1>() << " " << paramNames.b;
2061 if (isTypeValid<Arg2>())
2062 os << ", " << dataTypeNameOf<Arg2>() << " " << paramNames.c;
2063 if (isTypeValid<Arg3>())
2064 os << ", " << dataTypeNameOf<Arg3>() << " " << paramNames.d;
2065 os << ")\n{\n";
2066
2067 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2068 os << *m_body[ndx];
2069 os << "return " << *m_ret << ";\n";
2070 os << "}\n";
2071 }
2072
doApply(const EvalContext & ctx,const IArgs & args) const2073 IRet doApply (const EvalContext& ctx,
2074 const IArgs& args) const
2075 {
2076 Environment funEnv;
2077 IArgs& mutArgs = const_cast<IArgs&>(args);
2078 IRet ret;
2079
2080 initialize();
2081
2082 funEnv.bind(*m_var0, args.a);
2083 funEnv.bind(*m_var1, args.b);
2084 funEnv.bind(*m_var2, args.c);
2085 funEnv.bind(*m_var3, args.d);
2086
2087 {
2088 EvalContext funCtx(ctx.format, ctx.floatPrecision, funEnv, ctx.callDepth);
2089
2090 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2091 m_body[ndx]->execute(funCtx);
2092
2093 ret = m_ret->evaluate(funCtx);
2094 }
2095
2096 // \todo [lauri] Store references instead of values in environment
2097 const_cast<IArg0&>(mutArgs.a) = funEnv.lookup(*m_var0);
2098 const_cast<IArg1&>(mutArgs.b) = funEnv.lookup(*m_var1);
2099 const_cast<IArg2&>(mutArgs.c) = funEnv.lookup(*m_var2);
2100 const_cast<IArg3&>(mutArgs.d) = funEnv.lookup(*m_var3);
2101
2102 return ret;
2103 }
2104
doGetUsedFuncs(FuncSet & dst) const2105 void doGetUsedFuncs (FuncSet& dst) const
2106 {
2107 initialize();
2108 if (dst.insert(this).second)
2109 {
2110 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2111 m_body[ndx]->getUsedFuncs(dst);
2112 m_ret->getUsedFuncs(dst);
2113 }
2114 }
2115
2116 virtual ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args_) const = 0;
2117
2118 // These are transparently initialized when first needed. They cannot be
2119 // initialized in the constructor because they depend on the doExpand
2120 // method of the subclass.
2121
2122 mutable VariableP<Arg0> m_var0;
2123 mutable VariableP<Arg1> m_var1;
2124 mutable VariableP<Arg2> m_var2;
2125 mutable VariableP<Arg3> m_var3;
2126 mutable vector<StatementP> m_body;
2127 mutable ExprP<Ret> m_ret;
2128
2129 private:
2130
initialize(void) const2131 void initialize (void) const
2132 {
2133 if (!m_ret)
2134 {
2135 const ParamNames& paramNames = this->getParamNames();
2136 Counter symCounter;
2137 ExpandContext ctx (symCounter);
2138 ArgExprs args;
2139
2140 args.a = m_var0 = variable<Arg0>(paramNames.a);
2141 args.b = m_var1 = variable<Arg1>(paramNames.b);
2142 args.c = m_var2 = variable<Arg2>(paramNames.c);
2143 args.d = m_var3 = variable<Arg3>(paramNames.d);
2144
2145 m_ret = this->doExpand(ctx, args);
2146 m_body = ctx.getStatements();
2147 }
2148 }
2149 };
2150
2151 template <typename Sig>
2152 class PrimitiveFunc : public Func<Sig>
2153 {
2154 public:
2155 typedef typename PrimitiveFunc::Ret Ret;
2156 typedef typename PrimitiveFunc::ArgExprs ArgExprs;
2157
2158 protected:
doPrintDefinition(ostream &) const2159 void doPrintDefinition (ostream&) const {}
doGetUsedFuncs(FuncSet &) const2160 void doGetUsedFuncs (FuncSet&) const {}
2161 };
2162
2163 template <typename T>
2164 class Cond : public PrimitiveFunc<Signature<T, bool, T, T> >
2165 {
2166 public:
2167 typedef typename Cond::IArgs IArgs;
2168 typedef typename Cond::IRet IRet;
2169
getName(void) const2170 string getName (void) const
2171 {
2172 return "_cond";
2173 }
2174
2175 protected:
2176
doPrint(ostream & os,const BaseArgExprs & args) const2177 void doPrint (ostream& os, const BaseArgExprs& args) const
2178 {
2179 os << "(" << *args[0] << " ? " << *args[1] << " : " << *args[2] << ")";
2180 }
2181
doApply(const EvalContext &,const IArgs & iargs) const2182 IRet doApply (const EvalContext&, const IArgs& iargs)const
2183 {
2184 IRet ret;
2185
2186 if (iargs.a.contains(true))
2187 ret = unionIVal<T>(ret, iargs.b);
2188
2189 if (iargs.a.contains(false))
2190 ret = unionIVal<T>(ret, iargs.c);
2191
2192 return ret;
2193 }
2194 };
2195
2196 template <typename T>
2197 class CompareOperator : public PrimitiveFunc<Signature<bool, T, T> >
2198 {
2199 public:
2200 typedef typename CompareOperator::IArgs IArgs;
2201 typedef typename CompareOperator::IArg0 IArg0;
2202 typedef typename CompareOperator::IArg1 IArg1;
2203 typedef typename CompareOperator::IRet IRet;
2204
2205 protected:
doPrint(ostream & os,const BaseArgExprs & args) const2206 void doPrint (ostream& os, const BaseArgExprs& args) const
2207 {
2208 os << "(" << *args[0] << getSymbol() << *args[1] << ")";
2209 }
2210
doApply(const EvalContext &,const IArgs & iargs) const2211 Interval doApply (const EvalContext&, const IArgs& iargs) const
2212 {
2213 const IArg0& arg0 = iargs.a;
2214 const IArg1& arg1 = iargs.b;
2215 IRet ret;
2216
2217 if (canSucceed(arg0, arg1))
2218 ret |= true;
2219 if (canFail(arg0, arg1))
2220 ret |= false;
2221
2222 return ret;
2223 }
2224
2225 virtual string getSymbol (void) const = 0;
2226 virtual bool canSucceed (const IArg0&, const IArg1&) const = 0;
2227 virtual bool canFail (const IArg0&, const IArg1&) const = 0;
2228 };
2229
2230 template <typename T>
2231 class LessThan : public CompareOperator<T>
2232 {
2233 public:
getName(void) const2234 string getName (void) const { return "lessThan"; }
2235
2236 protected:
getSymbol(void) const2237 string getSymbol (void) const { return "<"; }
2238
canSucceed(const Interval & a,const Interval & b) const2239 bool canSucceed (const Interval& a, const Interval& b) const
2240 {
2241 return (a.lo() < b.hi());
2242 }
2243
canFail(const Interval & a,const Interval & b) const2244 bool canFail (const Interval& a, const Interval& b) const
2245 {
2246 return !(a.hi() < b.lo());
2247 }
2248 };
2249
2250 template <typename T>
operator <(const ExprP<T> & a,const ExprP<T> & b)2251 ExprP<bool> operator< (const ExprP<T>& a, const ExprP<T>& b)
2252 {
2253 return app<LessThan<T> >(a, b);
2254 }
2255
2256 template <typename T>
cond(const ExprP<bool> & test,const ExprP<T> & consequent,const ExprP<T> & alternative)2257 ExprP<T> cond (const ExprP<bool>& test,
2258 const ExprP<T>& consequent,
2259 const ExprP<T>& alternative)
2260 {
2261 return app<Cond<T> >(test, consequent, alternative);
2262 }
2263
2264 /*--------------------------------------------------------------------*//*!
2265 *
2266 * @}
2267 *
2268 *//*--------------------------------------------------------------------*/
2269 //Proper parameters for template T
2270 // Signature<float, float> 32bit tests
2271 // Signature<float, deFloat16> 16bit tests
2272 // Signature<double, double> 64bit tests
2273 template< class T>
2274 class FloatFunc1 : public PrimitiveFunc<T>
2275 {
2276 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0>::IArgs & iargs) const2277 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0>::IArgs& iargs) const
2278 {
2279 return this->applyMonotone(ctx, iargs.a);
2280 }
2281
applyMonotone(const EvalContext & ctx,const Interval & iarg0) const2282 Interval applyMonotone (const EvalContext& ctx, const Interval& iarg0) const
2283 {
2284 Interval ret;
2285
2286 TCU_INTERVAL_APPLY_MONOTONE1(ret, arg0, iarg0, val,
2287 TCU_SET_INTERVAL(val, point,
2288 point = this->applyPoint(ctx, arg0)));
2289
2290 ret |= innerExtrema(ctx, iarg0);
2291 ret &= (this->getCodomain(ctx) | TCU_NAN);
2292
2293 return ctx.format.convert(ret);
2294 }
2295
innerExtrema(const EvalContext &,const Interval &) const2296 virtual Interval innerExtrema (const EvalContext&, const Interval&) const
2297 {
2298 return Interval(); // empty interval, i.e. no extrema
2299 }
2300
applyPoint(const EvalContext & ctx,double arg0) const2301 virtual Interval applyPoint (const EvalContext& ctx, double arg0) const
2302 {
2303 const double exact = this->applyExact(arg0);
2304 const double prec = this->precision(ctx, exact, arg0);
2305
2306 return exact + Interval(-prec, prec);
2307 }
2308
applyExact(double) const2309 virtual double applyExact (double) const
2310 {
2311 TCU_THROW(InternalError, "Cannot apply");
2312 }
2313
getCodomain(const EvalContext &) const2314 virtual Interval getCodomain (const EvalContext&) const
2315 {
2316 return Interval::unbounded(true);
2317 }
2318
2319 virtual double precision (const EvalContext& ctx, double, double) const = 0;
2320 };
2321
2322 /*Proper parameters for template T
2323 Signature<double, double> 64bit tests
2324 Signature<float, float> 32bit tests
2325 Signature<float, deFloat16> 16bit tests*/
2326 template <class T>
2327 class CFloatFunc1 : public FloatFunc1<T>
2328 {
2329 public:
CFloatFunc1(const string & name,tcu::DoubleFunc1 & func)2330 CFloatFunc1 (const string& name, tcu::DoubleFunc1& func)
2331 : m_name(name), m_func(func) {}
2332
getName(void) const2333 string getName (void) const { return m_name; }
2334
2335 protected:
applyExact(double x) const2336 double applyExact (double x) const { return m_func(x); }
2337
2338 const string m_name;
2339 tcu::DoubleFunc1& m_func;
2340 };
2341
2342 //<Signature<float, deFloat16, deFloat16> >
2343 //<Signature<float, float, float> >
2344 //<Signature<double, double, double> >
2345 template <class T>
2346 class FloatFunc2 : public PrimitiveFunc<T>
2347 {
2348 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2349 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2350 {
2351 return this->applyMonotone(ctx, iargs.a, iargs.b);
2352 }
2353
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi) const2354 Interval applyMonotone (const EvalContext& ctx,
2355 const Interval& xi,
2356 const Interval& yi) const
2357 {
2358 Interval reti;
2359
2360 TCU_INTERVAL_APPLY_MONOTONE2(reti, x, xi, y, yi, ret,
2361 TCU_SET_INTERVAL(ret, point,
2362 point = this->applyPoint(ctx, x, y)));
2363 reti |= innerExtrema(ctx, xi, yi);
2364 reti &= (this->getCodomain(ctx) | TCU_NAN);
2365
2366 return ctx.format.convert(reti);
2367 }
2368
innerExtrema(const EvalContext &,const Interval &,const Interval &) const2369 virtual Interval innerExtrema (const EvalContext&,
2370 const Interval&,
2371 const Interval&) const
2372 {
2373 return Interval(); // empty interval, i.e. no extrema
2374 }
2375
applyPoint(const EvalContext & ctx,double x,double y) const2376 virtual Interval applyPoint (const EvalContext& ctx,
2377 double x,
2378 double y) const
2379 {
2380 const double exact = this->applyExact(x, y);
2381 const double prec = this->precision(ctx, exact, x, y);
2382
2383 return exact + Interval(-prec, prec);
2384 }
2385
applyExact(double,double) const2386 virtual double applyExact (double, double) const
2387 {
2388 TCU_THROW(InternalError, "Cannot apply");
2389 }
2390
getCodomain(const EvalContext &) const2391 virtual Interval getCodomain (const EvalContext&) const
2392 {
2393 return Interval::unbounded(true);
2394 }
2395
2396 virtual double precision (const EvalContext& ctx,
2397 double ret,
2398 double x,
2399 double y) const = 0;
2400 };
2401
2402 template <class T>
2403 class CFloatFunc2 : public FloatFunc2<T>
2404 {
2405 public:
CFloatFunc2(const string & name,tcu::DoubleFunc2 & func)2406 CFloatFunc2 (const string& name,
2407 tcu::DoubleFunc2& func)
2408 : m_name(name)
2409 , m_func(func)
2410 {
2411 }
2412
getName(void) const2413 string getName (void) const { return m_name; }
2414
2415 protected:
applyExact(double x,double y) const2416 double applyExact (double x, double y) const { return m_func(x, y); }
2417
2418 const string m_name;
2419 tcu::DoubleFunc2& m_func;
2420 };
2421
2422 template <class T>
2423 class InfixOperator : public FloatFunc2<T>
2424 {
2425 protected:
2426 virtual string getSymbol (void) const = 0;
2427
doPrint(ostream & os,const BaseArgExprs & args) const2428 void doPrint (ostream& os, const BaseArgExprs& args) const
2429 {
2430 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2431 }
2432
applyPoint(const EvalContext & ctx,double x,double y) const2433 Interval applyPoint (const EvalContext& ctx,
2434 double x,
2435 double y) const
2436 {
2437 const double exact = this->applyExact(x, y);
2438
2439 // Allow either representable number on both sides of the exact value,
2440 // but require exactly representable values to be preserved.
2441 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2442 }
2443
precision(const EvalContext &,double,double,double) const2444 double precision (const EvalContext&, double, double, double) const
2445 {
2446 return 0.0;
2447 }
2448 };
2449
2450 class InfixOperator16Bit : public FloatFunc2 <Signature<float, deFloat16, deFloat16> >
2451 {
2452 protected:
2453 virtual string getSymbol (void) const = 0;
2454
doPrint(ostream & os,const BaseArgExprs & args) const2455 void doPrint (ostream& os, const BaseArgExprs& args) const
2456 {
2457 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2458 }
2459
applyPoint(const EvalContext & ctx,double x,double y) const2460 Interval applyPoint (const EvalContext& ctx,
2461 double x,
2462 double y) const
2463 {
2464 const double exact = this->applyExact(x, y);
2465
2466 // Allow either representable number on both sides of the exact value,
2467 // but require exactly representable values to be preserved.
2468 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2469 }
2470
precision(const EvalContext &,double,double,double) const2471 double precision (const EvalContext&, double, double, double) const
2472 {
2473 return 0.0;
2474 }
2475 };
2476
2477 template <class T>
2478 class FloatFunc3 : public PrimitiveFunc<T>
2479 {
2480 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1,typename T::Arg2>::IArgs & iargs) const2481 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1, typename T::Arg2>::IArgs& iargs) const
2482 {
2483 return this->applyMonotone(ctx, iargs.a, iargs.b, iargs.c);
2484 }
2485
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi,const Interval & zi) const2486 Interval applyMonotone (const EvalContext& ctx,
2487 const Interval& xi,
2488 const Interval& yi,
2489 const Interval& zi) const
2490 {
2491 Interval reti;
2492 TCU_INTERVAL_APPLY_MONOTONE3(reti, x, xi, y, yi, z, zi, ret,
2493 TCU_SET_INTERVAL(ret, point,
2494 point = this->applyPoint(ctx, x, y, z)));
2495 return ctx.format.convert(reti);
2496 }
2497
applyPoint(const EvalContext & ctx,double x,double y,double z) const2498 virtual Interval applyPoint (const EvalContext& ctx,
2499 double x,
2500 double y,
2501 double z) const
2502 {
2503 const double exact = this->applyExact(x, y, z);
2504 const double prec = this->precision(ctx, exact, x, y, z);
2505 return exact + Interval(-prec, prec);
2506 }
2507
applyExact(double,double,double) const2508 virtual double applyExact (double, double, double) const
2509 {
2510 TCU_THROW(InternalError, "Cannot apply");
2511 }
2512
2513 virtual double precision (const EvalContext& ctx,
2514 double result,
2515 double x,
2516 double y,
2517 double z) const = 0;
2518 };
2519
2520 // We define syntactic sugar functions for expression constructors. Since
2521 // these have the same names as ordinary mathematical operations (sin, log
2522 // etc.), it's better to give them a dedicated namespace.
2523 namespace Functions
2524 {
2525
2526 using namespace tcu;
2527
2528 template <class T>
2529 class Comparison : public InfixOperator < T >
2530 {
2531 public:
getName(void) const2532 string getName (void) const { return "comparison"; }
getSymbol(void) const2533 string getSymbol (void) const { return ""; }
2534
getSpirvCase() const2535 SpirVCaseT getSpirvCase () const { return SPIRV_CASETYPE_COMPARE; }
2536
doApply(const EvalContext & ctx,const typename Comparison<T>::IArgs & iargs) const2537 Interval doApply (const EvalContext& ctx,
2538 const typename Comparison<T>::IArgs& iargs) const
2539 {
2540 DE_UNREF(ctx);
2541 if (iargs.a.hasNaN() || iargs.b.hasNaN())
2542 {
2543 return TCU_NAN; // one of the floats is NaN: block analysis
2544 }
2545
2546 int operationFlag = 1;
2547 int result = 0;
2548 const double a = iargs.a.midpoint();
2549 const double b = iargs.b.midpoint();
2550
2551 for (int i = 0; i<2; ++i)
2552 {
2553 if (a == b)
2554 result += operationFlag;
2555 operationFlag = operationFlag << 1;
2556
2557 if (a > b)
2558 result += operationFlag;
2559 operationFlag = operationFlag << 1;
2560
2561 if (a < b)
2562 result += operationFlag;
2563 operationFlag = operationFlag << 1;
2564
2565 if (a >= b)
2566 result += operationFlag;
2567 operationFlag = operationFlag << 1;
2568
2569 if (a <= b)
2570 result += operationFlag;
2571 operationFlag = operationFlag << 1;
2572 }
2573 return result;
2574 }
2575 };
2576
2577 template <class T>
2578 class Add : public InfixOperator < T >
2579 {
2580 public:
getName(void) const2581 string getName (void) const { return "add"; }
getSymbol(void) const2582 string getSymbol (void) const { return "+"; }
2583
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2584 Interval doApply (const EvalContext& ctx,
2585 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2586 {
2587 // Fast-path for common case
2588 if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2589 {
2590 Interval ret;
2591 TCU_SET_INTERVAL_BOUNDS(ret, sum,
2592 sum = iargs.a.lo() + iargs.b.lo(),
2593 sum = iargs.a.hi() + iargs.b.hi());
2594 return ctx.format.convert(ctx.format.roundOut(ret, true));
2595 }
2596 return this->applyMonotone(ctx, iargs.a, iargs.b);
2597 }
2598
2599 protected:
applyExact(double x,double y) const2600 double applyExact (double x, double y) const { return x + y; }
2601 };
2602
2603 template<class T>
2604 class Mul : public InfixOperator<T>
2605 {
2606 public:
getName(void) const2607 string getName (void) const { return "mul"; }
getSymbol(void) const2608 string getSymbol (void) const { return "*"; }
2609
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2610 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2611 {
2612 Interval a = iargs.a;
2613 Interval b = iargs.b;
2614
2615 // Fast-path for common case
2616 if (a.isOrdinary(ctx.format.getMaxValue()) && b.isOrdinary(ctx.format.getMaxValue()))
2617 {
2618 Interval ret;
2619 if (a.hi() < 0)
2620 {
2621 a = -a;
2622 b = -b;
2623 }
2624 if (a.lo() >= 0 && b.lo() >= 0)
2625 {
2626 TCU_SET_INTERVAL_BOUNDS(ret, prod,
2627 prod = a.lo() * b.lo(),
2628 prod = a.hi() * b.hi());
2629 return ctx.format.convert(ctx.format.roundOut(ret, true));
2630 }
2631 if (a.lo() >= 0 && b.hi() <= 0)
2632 {
2633 TCU_SET_INTERVAL_BOUNDS(ret, prod,
2634 prod = a.hi() * b.lo(),
2635 prod = a.lo() * b.hi());
2636 return ctx.format.convert(ctx.format.roundOut(ret, true));
2637 }
2638 }
2639 return this->applyMonotone(ctx, iargs.a, iargs.b);
2640 }
2641
2642 protected:
applyExact(double x,double y) const2643 double applyExact (double x, double y) const { return x * y; }
2644
innerExtrema(const EvalContext &,const Interval & xi,const Interval & yi) const2645 Interval innerExtrema(const EvalContext&, const Interval& xi, const Interval& yi) const
2646 {
2647 if (((xi.contains(-TCU_INFINITY) || xi.contains(TCU_INFINITY)) && yi.contains(0.0)) ||
2648 ((yi.contains(-TCU_INFINITY) || yi.contains(TCU_INFINITY)) && xi.contains(0.0)))
2649 return Interval(TCU_NAN);
2650
2651 return Interval();
2652 }
2653 };
2654
2655 template<class T>
2656 class Sub : public InfixOperator <T>
2657 {
2658 public:
getName(void) const2659 string getName (void) const { return "sub"; }
getSymbol(void) const2660 string getSymbol (void) const { return "-"; }
2661
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2662 Interval doApply (const EvalContext& ctx, const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs& iargs) const
2663 {
2664 // Fast-path for common case
2665 if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2666 {
2667 Interval ret;
2668
2669 TCU_SET_INTERVAL_BOUNDS(ret, diff,
2670 diff = iargs.a.lo() - iargs.b.hi(),
2671 diff = iargs.a.hi() - iargs.b.lo());
2672 return ctx.format.convert(ctx.format.roundOut(ret, true));
2673
2674 }
2675 else
2676 {
2677 return this->applyMonotone(ctx, iargs.a, iargs.b);
2678 }
2679 }
2680
2681 protected:
applyExact(double x,double y) const2682 double applyExact (double x, double y) const { return x - y; }
2683 };
2684
2685 template <class T>
2686 class Negate : public FloatFunc1<T>
2687 {
2688 public:
getName(void) const2689 string getName (void) const { return "_negate"; }
doPrint(ostream & os,const BaseArgExprs & args) const2690 void doPrint (ostream& os, const BaseArgExprs& args) const { os << "-" << *args[0]; }
2691
2692 protected:
precision(const EvalContext &,double,double) const2693 double precision (const EvalContext&, double, double) const { return 0.0; }
applyExact(double x) const2694 double applyExact (double x) const { return -x; }
2695 };
2696
2697 template <class T>
2698 class Div : public InfixOperator<T>
2699 {
2700 public:
getName(void) const2701 string getName (void) const { return "div"; }
2702
2703 protected:
getSymbol(void) const2704 string getSymbol (void) const { return "/"; }
2705
innerExtrema(const EvalContext &,const Interval & nom,const Interval & den) const2706 Interval innerExtrema (const EvalContext&,
2707 const Interval& nom,
2708 const Interval& den) const
2709 {
2710 Interval ret;
2711
2712 if (den.contains(0.0))
2713 {
2714 if (nom.contains(0.0))
2715 ret |= TCU_NAN;
2716
2717 if (nom.lo() < 0.0 || nom.hi() > 0.0)
2718 ret |= Interval::unbounded();
2719 }
2720
2721 return ret;
2722 }
2723
applyExact(double x,double y) const2724 double applyExact (double x, double y) const { return x / y; }
2725
applyPoint(const EvalContext & ctx,double x,double y) const2726 Interval applyPoint (const EvalContext& ctx, double x, double y) const
2727 {
2728 Interval ret = FloatFunc2<T>::applyPoint(ctx, x, y);
2729
2730 if (!deIsInf(x) && !deIsInf(y) && y != 0.0)
2731 {
2732 const Interval dst = ctx.format.convert(ret);
2733 if (dst.contains(-TCU_INFINITY)) ret |= -ctx.format.getMaxValue();
2734 if (dst.contains(+TCU_INFINITY)) ret |= +ctx.format.getMaxValue();
2735 }
2736
2737 return ret;
2738 }
2739
precision(const EvalContext & ctx,double ret,double,double den) const2740 double precision (const EvalContext& ctx, double ret, double, double den) const
2741 {
2742 const FloatFormat& fmt = ctx.format;
2743
2744 // \todo [2014-03-05 lauri] Check that the limits in GLSL 3.10 are actually correct.
2745 // For now, we assume that division's precision is 2.5 ULP when the value is within
2746 // [2^MINEXP, 2^MAXEXP-1]
2747
2748 if (den == 0.0)
2749 return 0.0; // Result must be exactly inf
2750 else if (de::inBounds(deAbs(den),
2751 deLdExp(1.0, fmt.getMinExp()),
2752 deLdExp(1.0, fmt.getMaxExp() - 1)))
2753 return fmt.ulp(ret, 2.5);
2754 else
2755 return TCU_INFINITY; // Can be any number, but must be a number.
2756 }
2757 };
2758
2759 template <class T>
2760 class InverseSqrt : public FloatFunc1 <T>
2761 {
2762 public:
getName(void) const2763 string getName (void) const { return "inversesqrt"; }
2764
2765 protected:
applyExact(double x) const2766 double applyExact (double x) const { return 1.0 / deSqrt(x); }
2767
precision(const EvalContext & ctx,double ret,double x) const2768 double precision (const EvalContext& ctx, double ret, double x) const
2769 {
2770 return x <= 0 ? TCU_NAN : ctx.format.ulp(ret, 2.0);
2771 }
2772
getCodomain(const EvalContext &) const2773 Interval getCodomain (const EvalContext&) const
2774 {
2775 return Interval(0.0, TCU_INFINITY);
2776 }
2777 };
2778
2779 template <class T>
2780 class ExpFunc : public CFloatFunc1<T>
2781 {
2782 public:
ExpFunc(const string & name,DoubleFunc1 & func)2783 ExpFunc (const string& name, DoubleFunc1& func)
2784 : CFloatFunc1<T> (name, func)
2785 {}
2786 protected:
2787 double precision (const EvalContext& ctx, double ret, double x) const;
getCodomain(const EvalContext &) const2788 Interval getCodomain (const EvalContext&) const
2789 {
2790 return Interval(0.0, TCU_INFINITY);
2791 }
2792 };
2793
2794 template <>
precision(const EvalContext & ctx,double ret,double x) const2795 double ExpFunc <Signature<float, float> >::precision (const EvalContext& ctx, double ret, double x) const
2796 {
2797 switch (ctx.floatPrecision)
2798 {
2799 case glu::PRECISION_HIGHP:
2800 return ctx.format.ulp(ret, 3.0 + 2.0 * deAbs(x));
2801 case glu::PRECISION_MEDIUMP:
2802 case glu::PRECISION_LAST:
2803 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2804 default:
2805 DE_FATAL("Impossible");
2806 }
2807
2808 return 0.0;
2809 }
2810
2811 template <>
precision(const EvalContext & ctx,double ret,double x) const2812 double ExpFunc <Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double x) const
2813 {
2814 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2815 }
2816
2817 template <>
precision(const EvalContext & ctx,double ret,double x) const2818 double ExpFunc <Signature<double, double> >::precision(const EvalContext& ctx, double ret, double x) const
2819 {
2820 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2821 }
2822
2823 template <class T>
Exp2(void)2824 class Exp2 : public ExpFunc<T> { public: Exp2 (void) : ExpFunc<T>("exp2", deExp2) {} };
2825 template <class T>
Exp(void)2826 class Exp : public ExpFunc<T> { public: Exp (void) : ExpFunc<T>("exp", deExp) {} };
2827
2828 template <typename T>
exp2(const ExprP<T> & x)2829 ExprP<T> exp2 (const ExprP<T>& x) { return app<Exp2< Signature<T, T> > >(x); }
2830 template <typename T>
exp(const ExprP<T> & x)2831 ExprP<T> exp (const ExprP<T>& x) { return app<Exp< Signature<T, T> > >(x); }
2832
2833 template <class T>
2834 class LogFunc : public CFloatFunc1<T>
2835 {
2836 public:
LogFunc(const string & name,DoubleFunc1 & func)2837 LogFunc (const string& name, DoubleFunc1& func)
2838 : CFloatFunc1<T>(name, func) {}
2839
2840 protected:
2841 double precision (const EvalContext& ctx, double ret, double x) const;
2842 };
2843
2844 template <>
precision(const EvalContext & ctx,double ret,double x) const2845 double LogFunc<Signature<float, float> >::precision(const EvalContext& ctx, double ret, double x) const
2846 {
2847 if (x <= 0)
2848 return TCU_NAN;
2849
2850 switch (ctx.floatPrecision)
2851 {
2852 case glu::PRECISION_HIGHP:
2853 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
2854 case glu::PRECISION_MEDIUMP:
2855 case glu::PRECISION_LAST:
2856 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
2857 default:
2858 DE_FATAL("Impossible");
2859 }
2860
2861 return 0;
2862 }
2863
2864 template <>
precision(const EvalContext & ctx,double ret,double x) const2865 double LogFunc<Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double x) const
2866 {
2867 if (x <= 0)
2868 return TCU_NAN;
2869 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
2870 }
2871
2872 // Spec: "The precision of double-precision instructions is at least that of single precision."
2873 // Lets pick float high precision as a reference.
2874 template <>
precision(const EvalContext & ctx,double ret,double x) const2875 double LogFunc<Signature<double, double> >::precision(const EvalContext& ctx, double ret, double x) const
2876 {
2877 if (x <= 0)
2878 return TCU_NAN;
2879 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
2880 }
2881
2882 template <class T>
Log2(void)2883 class Log2 : public LogFunc<T> { public: Log2 (void) : LogFunc<T>("log2", deLog2) {} };
2884 template <class T>
Log(void)2885 class Log : public LogFunc<T> { public: Log (void) : LogFunc<T>("log", deLog) {} };
2886
log2(const ExprP<float> & x)2887 ExprP<float> log2 (const ExprP<float>& x) { return app<Log2< Signature<float, float> > >(x); }
log(const ExprP<float> & x)2888 ExprP<float> log (const ExprP<float>& x) { return app<Log< Signature<float, float> > >(x); }
2889
log2(const ExprP<deFloat16> & x)2890 ExprP<deFloat16> log2 (const ExprP<deFloat16>& x) { return app<Log2< Signature<deFloat16, deFloat16> > >(x); }
log(const ExprP<deFloat16> & x)2891 ExprP<deFloat16> log (const ExprP<deFloat16>& x) { return app<Log< Signature<deFloat16, deFloat16> > >(x); }
2892
log2(const ExprP<double> & x)2893 ExprP<double> log2 (const ExprP<double>& x) { return app<Log2< Signature<double, double> > >(x); }
log(const ExprP<double> & x)2894 ExprP<double> log (const ExprP<double>& x) { return app<Log< Signature<double, double> > >(x); }
2895
// Defines a free function NAME(arg0) that builds an expression node for CLASS via app<>().
#define DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0) \
ExprP<TRET> NAME (const ExprP<T0>& arg0) { return app<CLASS>(arg0); }

// Defines a unary DerivedFunc class CLASS with signature TRET(T0). getName() returns
// NAME stringified, and doExpand() returns the expression EXPANSION, in which ARG0
// names the argument expression. Also defines the NAME() constructor helper.
#define DEFINE_DERIVED1(CLASS, TRET, NAME, T0, ARG0, EXPANSION) \
class CLASS : public DerivedFunc<Signature<TRET, T0> > /* NOLINT(CLASS) */ \
{ \
public: \
	string getName (void) const { return #NAME; } \
\
protected: \
	ExprP<TRET> doExpand (ExpandContext&, \
						  const CLASS::ArgExprs& args_) const \
	{ \
		const ExprP<T0>& ARG0 = args_.a; \
		return EXPANSION; \
	} \
}; \
DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)

// DEFINE_DERIVED1 for double(double) functions.
#define DEFINE_DERIVED_DOUBLE1(CLASS, NAME, ARG0, EXPANSION) \
	DEFINE_DERIVED1(CLASS, double, NAME, double, ARG0, EXPANSION)

// DEFINE_DERIVED1 for float(float) functions.
#define DEFINE_DERIVED_FLOAT1(CLASS, NAME, ARG0, EXPANSION) \
	DEFINE_DERIVED1(CLASS, float, NAME, float, ARG0, EXPANSION)
2921
// Like DEFINE_DERIVED1, but the generated class also overrides getInputRange() to
// return INTERVAL regardless of the is16bit flag.
#define DEFINE_DERIVED1_INPUTRANGE(CLASS, TRET, NAME, T0, ARG0, EXPANSION, INTERVAL) \
class CLASS : public DerivedFunc<Signature<TRET, T0> > /* NOLINT(CLASS) */ \
{ \
public: \
	string getName (void) const { return #NAME; } \
\
protected: \
	ExprP<TRET> doExpand (ExpandContext&, \
						  const CLASS::ArgExprs& args_) const \
	{ \
		const ExprP<T0>& ARG0 = args_.a; \
		return EXPANSION; \
	} \
	Interval getInputRange (const bool /*is16bit*/) const \
	{ \
		return INTERVAL; \
	} \
}; \
DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)

// DEFINE_DERIVED1_INPUTRANGE for float(float) functions.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
	DEFINE_DERIVED1_INPUTRANGE(CLASS, float, NAME, float, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1_INPUTRANGE for double(double) functions.
#define DEFINE_DERIVED_DOUBLE1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
	DEFINE_DERIVED1_INPUTRANGE(CLASS, double, NAME, double, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1 for deFloat16(deFloat16) functions.
#define DEFINE_DERIVED_FLOAT1_16BIT(CLASS, NAME, ARG0, EXPANSION) \
	DEFINE_DERIVED1(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION)

// DEFINE_DERIVED1_INPUTRANGE for deFloat16(deFloat16) functions.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
	DEFINE_DERIVED1_INPUTRANGE(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION, INTERVAL)
2953
// Defines a free function NAME(arg0, arg1) that builds an expression node for CLASS via app<>().
#define DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1) \
ExprP<TRET> NAME (const ExprP<T0>& arg0, const ExprP<T1>& arg1) \
{ \
	return app<CLASS>(arg0, arg1); \
}

// Defines a binary DerivedFunc class CLASS with signature TRET(T0, T1). getName()
// returns NAME stringified, getSpirvCase() returns SPIRVCASE, and doExpand() returns
// EXPANSION, in which Arg0/Arg1 name the argument expressions. Also defines the
// NAME() constructor helper.
#define DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRVCASE) \
class CLASS : public DerivedFunc<Signature<TRET, T0, T1> > /* NOLINT(CLASS) */ \
{ \
public: \
	string getName (void) const { return #NAME; } \
\
	SpirVCaseT getSpirvCase(void) const { return SPIRVCASE; } \
\
protected: \
	ExprP<TRET> doExpand (ExpandContext&, const ArgExprs& args_) const \
	{ \
		const ExprP<T0>& Arg0 = args_.a; \
		const ExprP<T1>& Arg1 = args_.b; \
		return EXPANSION; \
	} \
}; \
DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)

// DEFINE_CASED_DERIVED2 with SPIRV_CASETYPE_NONE as the SPIR-V case.
#define DEFINE_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION) \
	DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRV_CASETYPE_NONE)

// Type-specific shorthands: all operands and the result share one scalar type.
#define DEFINE_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION) \
	DEFINE_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION)

#define DEFINE_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION) \
	DEFINE_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION)

#define DEFINE_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION) \
	DEFINE_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION)

#define DEFINE_CASED_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
	DEFINE_CASED_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION, SPIRVCASE)

#define DEFINE_CASED_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
	DEFINE_CASED_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION, SPIRVCASE)

#define DEFINE_CASED_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
	DEFINE_CASED_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION, SPIRVCASE)
2998
// Defines a free function NAME(arg0, arg1, arg2) that builds an expression node for CLASS.
#define DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2) \
ExprP<TRET> NAME (const ExprP<T0>& arg0, const ExprP<T1>& arg1, const ExprP<T2>& arg2) \
{ \
	return app<CLASS>(arg0, arg1, arg2); \
}

// Defines a ternary DerivedFunc class CLASS with signature TRET(T0, T1, T2). getName()
// returns NAME stringified, and doExpand() returns EXPANSION, in which ARG0..ARG2 name
// the argument expressions. Also defines the NAME() constructor helper.
#define DEFINE_DERIVED3(CLASS, TRET, NAME, T0, ARG0, T1, ARG1, T2, ARG2, EXPANSION) \
class CLASS : public DerivedFunc<Signature<TRET, T0, T1, T2> > /* NOLINT(CLASS) */ \
{ \
public: \
	string getName (void) const { return #NAME; } \
\
protected: \
	ExprP<TRET> doExpand (ExpandContext&, const ArgExprs& args_) const \
	{ \
		const ExprP<T0>& ARG0 = args_.a; \
		const ExprP<T1>& ARG1 = args_.b; \
		const ExprP<T2>& ARG2 = args_.c; \
		return EXPANSION; \
	} \
}; \
DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)

// Type-specific shorthands: all operands and the result share one scalar type.
#define DEFINE_DERIVED_DOUBLE3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
	DEFINE_DERIVED3(CLASS, double, NAME, double, ARG0, double, ARG1, double, ARG2, EXPANSION)

#define DEFINE_DERIVED_FLOAT3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
	DEFINE_DERIVED3(CLASS, float, NAME, float, ARG0, float, ARG1, float, ARG2, EXPANSION)

#define DEFINE_DERIVED_FLOAT3_16BIT(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
	DEFINE_DERIVED3(CLASS, deFloat16, NAME, deFloat16, ARG0, deFloat16, ARG1, deFloat16, ARG2, EXPANSION)

// Defines a free function NAME(arg0, arg1, arg2, arg3) that builds an expression node for CLASS.
#define DEFINE_CONSTRUCTOR4(CLASS, TRET, NAME, T0, T1, T2, T3) \
ExprP<TRET> NAME (const ExprP<T0>& arg0, const ExprP<T1>& arg1, \
				  const ExprP<T2>& arg2, const ExprP<T3>& arg3) \
{ \
	return app<CLASS>(arg0, arg1, arg2, arg3); \
}
3037
// Signature-specific inverse square roots, used below to derive sqrt.
typedef InverseSqrt< Signature<deFloat16, deFloat16> >	InverseSqrt16Bit;
typedef InverseSqrt< Signature<float, float> >			InverseSqrt32Bit;
typedef InverseSqrt< Signature<double, double> >		InverseSqrt64Bit;

// sqrt(x) is expanded as 1 / inversesqrt(x).
DEFINE_DERIVED_FLOAT1(Sqrt32Bit, sqrt, x, constant(1.0f) / app<InverseSqrt32Bit>(x))
DEFINE_DERIVED_FLOAT1_16BIT(Sqrt16Bit, sqrt, x, constant((deFloat16)FLOAT16_1_0) / app<InverseSqrt16Bit>(x))
DEFINE_DERIVED_DOUBLE1(Sqrt64Bit, sqrt, x, constant(1.0) / app<InverseSqrt64Bit>(x))
// pow(x, y) is expanded as exp2(y * log2(x)).
DEFINE_DERIVED_FLOAT2(Pow, pow, x, y, exp2<float>(y * log2(x)))
DEFINE_DERIVED_FLOAT2_16BIT(Pow16, pow, x, y, exp2<deFloat16>(y * log2(x)))
// radians/degrees multiply by the constant pi/180 or 180/pi, respectively.
DEFINE_DERIVED_FLOAT1(Radians, radians, d, f32Constant(DE_PI_DOUBLE / 180.0f) * d)
DEFINE_DERIVED_FLOAT1_16BIT(Radians16, radians, d, f16Constant(DE_PI_DOUBLE / 180.0f) * d)
DEFINE_DERIVED_FLOAT1(Degrees, degrees, r, f32Constant(180.0 / DE_PI_DOUBLE) * r)
DEFINE_DERIVED_FLOAT1_16BIT(Degrees16, degrees, r, f16Constant(180.0 / DE_PI_DOUBLE) * r)
3051
/* Base class for periodic trigonometric functions (sin, cos).
 * Template parameter T is the function signature; getInputRange() and precision()
 * are specialized below for:
 *   Signature<float, float>          32bit tests
 *   Signature<deFloat16, deFloat16>  16bit tests
 *   Signature<double, double>        64bit tests */
template<class T>
class TrigFunc : public CFloatFunc1<T>
{
public:
	//! loEx/hiEx are the low and high extremum values the function can reach
	//! (the Sin/Cos subclasses pass -1.0 and 1.0). innerExtrema() adds them when
	//! the argument interval passes through a turning point.
	TrigFunc (const string&		name,
			  DoubleFunc1&		func,
			  const Interval&	loEx,
			  const Interval&	hiEx)
		: CFloatFunc1<T>	(name, func)
		, m_loExtremum		(loEx)
		, m_hiExtremum		(hiEx) {}

protected:
	//! Values reached strictly inside the argument interval that evaluating only the
	//! endpoints would miss. Decided from the sign of the slope (doGetSlope) at each end.
	Interval innerExtrema (const EvalContext&, const Interval& angle) const
	{
		const double		lo			= angle.lo();
		const double		hi			= angle.hi();
		const int			loSlope		= doGetSlope(lo);
		const int			hiSlope		= doGetSlope(hi);

		// Detect the high and low values the function can take between the
		// interval endpoints.
		if (angle.length() >= 2.0 * DE_PI_DOUBLE)
		{
			// The interval is longer than a full cycle, so it must get all possible values.
			return m_hiExtremum | m_loExtremum;
		}
		else if (loSlope == 1 && hiSlope == -1)
		{
			// The slope can change from positive to negative only at the maximum value.
			return m_hiExtremum;
		}
		else if (loSlope == -1 && hiSlope == 1)
		{
			// The slope can change from negative to positive only at the minimum value.
			return m_loExtremum;
		}
		else if (loSlope == hiSlope &&
				 deIntSign(CFloatFunc1<T>::applyExact(hi) - CFloatFunc1<T>::applyExact(lo)) * loSlope == -1)
		{
			// The slope has changed twice between the endpoints, so both extrema are included.
			return m_hiExtremum | m_loExtremum;
		}

		return Interval();
	}

	Interval getCodomain (const EvalContext&) const
	{
		// Ensure that result is always within [-1, 1], or NaN (for +-inf)
		return Interval(-1.0, 1.0) | TCU_NAN;
	}

	//! Allowed error; specialized per signature below.
	double precision (const EvalContext& ctx, double ret, double arg) const;

	//! Argument range sampled by the tests; specialized per signature below.
	Interval getInputRange (const bool is16bit) const;
	//! Sign of the derivative at angle: 1, -1, or 0 (as produced by deIntSign in the subclasses).
	virtual int doGetSlope (double angle) const = 0;

	Interval m_loExtremum;
	Interval m_hiExtremum;
};
3116
3117 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3118 template<>
getInputRange(const bool is16bit) const3119 Interval TrigFunc<Signature<float, float> >::getInputRange(const bool is16bit) const
3120 {
3121 DE_UNREF(is16bit);
3122 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3123 }
3124
3125 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3126 template<>
getInputRange(const bool is16bit) const3127 Interval TrigFunc<Signature<deFloat16, deFloat16> >::getInputRange(const bool is16bit) const
3128 {
3129 DE_UNREF(is16bit);
3130 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3131 }
3132
3133 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3134 template<>
getInputRange(const bool is16bit) const3135 Interval TrigFunc<Signature<double, double> >::getInputRange(const bool is16bit) const
3136 {
3137 DE_UNREF(is16bit);
3138 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3139 }
3140
3141 template<>
precision(const EvalContext & ctx,double ret,double arg) const3142 double TrigFunc<Signature<float, float> >::precision(const EvalContext& ctx, double ret, double arg) const
3143 {
3144 DE_UNREF(ret);
3145 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3146 {
3147 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3148 return deLdExp(1.0, -11);
3149 else
3150 {
3151 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3152 // 2^-11 at x == pi.
3153 return deLdExp(deAbs(arg), -12);
3154 }
3155 }
3156 else
3157 {
3158 DE_ASSERT(ctx.floatPrecision == glu::PRECISION_MEDIUMP || ctx.floatPrecision == glu::PRECISION_LAST);
3159
3160 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3161 return deLdExp(1.0, -7);
3162 else
3163 {
3164 // |x| * 2^-8, slightly larger than 2^-7 at x == pi
3165 return deLdExp(deAbs(arg), -8);
3166 }
3167 }
3168 }
3169 //
3170 /*
3171 * Half tests
3172 * From Spec:
3173 * Absolute error 2^{-7} inside the range [-pi, pi].
3174 */
3175 template<>
precision(const EvalContext & ctx,double ret,double arg) const3176 double TrigFunc<Signature<deFloat16, deFloat16> >::precision(const EvalContext& ctx, double ret, double arg) const
3177 {
3178 DE_UNREF(ctx);
3179 DE_UNREF(ret);
3180 DE_UNREF(arg);
3181 DE_ASSERT(-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE && ctx.floatPrecision == glu::PRECISION_LAST);
3182 return deLdExp(1.0, -7);
3183 }
3184
3185 // Spec: "The precision of double-precision instructions is at least that of single precision."
3186 // Lets pick float high precision as a reference.
3187 template<>
precision(const EvalContext & ctx,double ret,double arg) const3188 double TrigFunc<Signature<double, double> >::precision(const EvalContext& ctx, double ret, double arg) const
3189 {
3190 DE_UNREF(ctx);
3191 DE_UNREF(ret);
3192 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3193 return deLdExp(1.0, -11);
3194 else
3195 {
3196 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3197 // 2^-11 at x == pi.
3198 return deLdExp(deAbs(arg), -12);
3199 }
3200 }
3201
3202 /*Proper parameters for template T
3203 Signature<float, float> 32bit tests
3204 Signature<float, deFloat16> 16bit tests*/
3205 template <class T>
3206 class Sin : public TrigFunc<T>
3207 {
3208 public:
Sin(void)3209 Sin (void) : TrigFunc<T>("sin", deSin, -1.0, 1.0) {}
3210
3211 protected:
doGetSlope(double angle) const3212 int doGetSlope (double angle) const { return deIntSign(deCos(angle)); }
3213 };
3214
sin(const ExprP<float> & x)3215 ExprP<float> sin (const ExprP<float>& x) { return app<Sin<Signature<float, float> > >(x); }
sin(const ExprP<deFloat16> & x)3216 ExprP<deFloat16> sin (const ExprP<deFloat16>& x) { return app<Sin<Signature<deFloat16, deFloat16> > >(x); }
sin(const ExprP<double> & x)3217 ExprP<double> sin (const ExprP<double>& x) { return app<Sin<Signature<double, double> > >(x); }
3218
3219 template <class T>
3220 class Cos : public TrigFunc<T>
3221 {
3222 public:
Cos(void)3223 Cos (void) : TrigFunc<T> ("cos", deCos, -1.0, 1.0) {}
3224
3225 protected:
doGetSlope(double angle) const3226 int doGetSlope (double angle) const { return -deIntSign(deSin(angle)); }
3227 };
3228
cos(const ExprP<float> & x)3229 ExprP<float> cos (const ExprP<float>& x) { return app<Cos<Signature<float, float> > >(x); }
cos(const ExprP<deFloat16> & x)3230 ExprP<deFloat16> cos (const ExprP<deFloat16>& x) { return app<Cos<Signature<deFloat16, deFloat16> > >(x); }
cos(const ExprP<double> & x)3231 ExprP<double> cos (const ExprP<double>& x) { return app<Cos<Signature<double, double> > >(x); }
3232
// tan(x) is expanded as sin(x) * (1 / cos(x)); getInputRange() returns [-pi, pi].
DEFINE_DERIVED_FLOAT1_INPUTRANGE(Tan, tan, x, sin(x) * (constant(1.0f) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(Tan16Bit, tan, x, sin(x) * (constant((deFloat16)FLOAT16_1_0) / cos(x)), Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
3235
3236 template <class T>
3237 class ATan : public CFloatFunc1<T>
3238 {
3239 public:
ATan(void)3240 ATan (void) : CFloatFunc1<T> ("atan", deAtanOver) {}
3241
3242 protected:
precision(const EvalContext & ctx,double ret,double) const3243 double precision (const EvalContext& ctx, double ret, double) const
3244 {
3245 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3246 return ctx.format.ulp(ret, 4096.0);
3247 else
3248 return ctx.format.ulp(ret, 5.0);
3249 }
3250
getCodomain(const EvalContext & ctx) const3251 Interval getCodomain(const EvalContext& ctx) const
3252 {
3253 return ctx.format.roundOut(Interval(-0.5 * DE_PI_DOUBLE, 0.5 * DE_PI_DOUBLE), true);
3254 }
3255 };
3256
3257 template <class T>
3258 class ATan2 : public CFloatFunc2<T>
3259 {
3260 public:
ATan2(void)3261 ATan2 (void) : CFloatFunc2<T> ("atan", deAtan2) {}
3262
3263 protected:
innerExtrema(const EvalContext & ctx,const Interval & yi,const Interval & xi) const3264 Interval innerExtrema (const EvalContext& ctx,
3265 const Interval& yi,
3266 const Interval& xi) const
3267 {
3268 Interval ret;
3269
3270 if (yi.contains(0.0))
3271 {
3272 if (xi.contains(0.0))
3273 ret |= TCU_NAN;
3274 if (xi.intersects(Interval(-TCU_INFINITY, 0.0)))
3275 ret |= ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3276 }
3277
3278 if (!yi.isFinite(ctx.format.getMaxValue()) || !xi.isFinite(ctx.format.getMaxValue()))
3279 {
3280 // Infinities may not be supported, allow anything, including NaN
3281 ret |= TCU_NAN;
3282 }
3283
3284 return ret;
3285 }
3286
precision(const EvalContext & ctx,double ret,double,double) const3287 double precision (const EvalContext& ctx, double ret, double, double) const
3288 {
3289 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3290 return ctx.format.ulp(ret, 4096.0);
3291 else
3292 return ctx.format.ulp(ret, 5.0);
3293 }
3294
getCodomain(const EvalContext & ctx) const3295 Interval getCodomain(const EvalContext& ctx) const
3296 {
3297 return ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3298 }
3299 };
3300
atan2(const ExprP<float> & x,const ExprP<float> & y)3301 ExprP<float> atan2 (const ExprP<float>& x, const ExprP<float>& y) { return app<ATan2<Signature<float, float, float> > >(x, y); }
3302
atan2(const ExprP<deFloat16> & x,const ExprP<deFloat16> & y)3303 ExprP<deFloat16> atan2 (const ExprP<deFloat16>& x, const ExprP<deFloat16>& y) { return app<ATan2<Signature<deFloat16, deFloat16, deFloat16> > >(x, y); }
3304
atan2(const ExprP<double> & x,const ExprP<double> & y)3305 ExprP<double> atan2 (const ExprP<double>& x, const ExprP<double>& y) { return app<ATan2<Signature<double, double, double> > >(x, y); }
3306
3307
// Hyperbolic functions, expanded in terms of exp():
//   sinh(x) = (e^x - e^-x) / 2,  cosh(x) = (e^x + e^-x) / 2,  tanh(x) = sinh(x) / cosh(x).
DEFINE_DERIVED_FLOAT1(Sinh, sinh, x, (exp<float>(x) - exp<float>(-x)) / constant(2.0f))
DEFINE_DERIVED_FLOAT1(Cosh, cosh, x, (exp<float>(x) + exp<float>(-x)) / constant(2.0f))
DEFINE_DERIVED_FLOAT1(Tanh, tanh, x, sinh(x) / cosh(x))

// 16-bit variants of the hyperbolic functions, with the same expansions.
DEFINE_DERIVED_FLOAT1_16BIT(Sinh16Bit, sinh, x, (exp(x) - exp(-x)) / constant((deFloat16)FLOAT16_2_0))
DEFINE_DERIVED_FLOAT1_16BIT(Cosh16Bit, cosh, x, (exp(x) + exp(-x)) / constant((deFloat16)FLOAT16_2_0))
DEFINE_DERIVED_FLOAT1_16BIT(Tanh16Bit, tanh, x, sinh(x) / cosh(x))

// Inverse trigonometric/hyperbolic functions, expanded via atan2, sqrt and log.
// acosh accepts either of two algebraically equal forms of x^2 - 1 (via alternatives()).
DEFINE_DERIVED_FLOAT1(ASin, asin, x, atan2(x, sqrt(constant(1.0f) - x * x)))
DEFINE_DERIVED_FLOAT1(ACos, acos, x, atan2(sqrt(constant(1.0f) - x * x), x))
DEFINE_DERIVED_FLOAT1(ASinh, asinh, x, log(x + sqrt(x * x + constant(1.0f))))
DEFINE_DERIVED_FLOAT1(ACosh, acosh, x, log(x + sqrt(alternatives((x + constant(1.0f)) * (x - constant(1.0f)),
																 (x * x - constant(1.0f))))))
DEFINE_DERIVED_FLOAT1(ATanh, atanh, x, constant(0.5f) * log((constant(1.0f) + x) /
															(constant(1.0f) - x)))

// 16-bit variants of the inverse functions, with the same expansions.
DEFINE_DERIVED_FLOAT1_16BIT(ASin16Bit, asin, x, atan2(x, sqrt(constant((deFloat16)FLOAT16_1_0) - x * x)))
DEFINE_DERIVED_FLOAT1_16BIT(ACos16Bit, acos, x, atan2(sqrt(constant((deFloat16)FLOAT16_1_0) - x * x), x))
DEFINE_DERIVED_FLOAT1_16BIT(ASinh16Bit, asinh, x, log(x + sqrt(x * x + constant((deFloat16)FLOAT16_1_0))))
DEFINE_DERIVED_FLOAT1_16BIT(ACosh16Bit, acosh, x, log(x + sqrt(alternatives((x + constant((deFloat16)FLOAT16_1_0)) * (x - constant((deFloat16)FLOAT16_1_0)),
																 (x * x - constant((deFloat16)FLOAT16_1_0))))))
DEFINE_DERIVED_FLOAT1_16BIT(ATanh16Bit, atanh, x, constant((deFloat16)FLOAT16_0_5) * log((constant((deFloat16)FLOAT16_1_0) + x) /
															(constant((deFloat16)FLOAT16_1_0) - x)))
3331
3332 template <typename T>
3333 class GetComponent : public PrimitiveFunc<Signature<typename T::Element, T, int> >
3334 {
3335 public:
3336 typedef typename GetComponent::IRet IRet;
3337
getName(void) const3338 string getName (void) const { return "_getComponent"; }
3339
print(ostream & os,const BaseArgExprs & args) const3340 void print (ostream& os,
3341 const BaseArgExprs& args) const
3342 {
3343 os << *args[0] << "[" << *args[1] << "]";
3344 }
3345
3346 protected:
doApply(const EvalContext &,const typename GetComponent::IArgs & iargs) const3347 IRet doApply (const EvalContext&,
3348 const typename GetComponent::IArgs& iargs) const
3349 {
3350 IRet ret;
3351
3352 for (int compNdx = 0; compNdx < T::SIZE; ++compNdx)
3353 {
3354 if (iargs.b.contains(compNdx))
3355 ret = unionIVal<typename T::Element>(ret, iargs.a[compNdx]);
3356 }
3357
3358 return ret;
3359 }
3360
3361 };
3362
3363 template <typename T>
getComponent(const ExprP<T> & container,int ndx)3364 ExprP<typename T::Element> getComponent (const ExprP<T>& container, int ndx)
3365 {
3366 DE_ASSERT(0 <= ndx && ndx < T::SIZE);
3367 return app<GetComponent<T> >(container, constant(ndx));
3368 }
3369
3370 template <typename T> string vecNamePrefix (void);
vecNamePrefix(void)3371 template <> string vecNamePrefix<float> (void) { return ""; }
vecNamePrefix(void)3372 template <> string vecNamePrefix<deFloat16>(void) { return ""; }
vecNamePrefix(void)3373 template <> string vecNamePrefix<double> (void) { return "d"; }
vecNamePrefix(void)3374 template <> string vecNamePrefix<int> (void) { return "i"; }
vecNamePrefix(void)3375 template <> string vecNamePrefix<bool> (void) { return "b"; }
3376
3377 template <typename T, int Size>
vecName(void)3378 string vecName (void) { return vecNamePrefix<T>() + "vec" + de::toString(Size); }
3379
3380 template <typename T, int Size> class GenVec;
3381
3382 template <typename T>
3383 class GenVec<T, 1> : public DerivedFunc<Signature<T, T> >
3384 {
3385 public:
3386 typedef typename GenVec<T, 1>::ArgExprs ArgExprs;
3387
getName(void) const3388 string getName (void) const
3389 {
3390 return "_" + vecName<T, 1>();
3391 }
3392
3393 protected:
3394
doExpand(ExpandContext &,const ArgExprs & args) const3395 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const { return args.a; }
3396 };
3397
3398 template <typename T>
3399 class GenVec<T, 2> : public PrimitiveFunc<Signature<Vector<T, 2>, T, T> >
3400 {
3401 public:
3402 typedef typename GenVec::IRet IRet;
3403 typedef typename GenVec::IArgs IArgs;
3404
getName(void) const3405 string getName (void) const
3406 {
3407 return vecName<T, 2>();
3408 }
3409
3410 protected:
doApply(const EvalContext &,const IArgs & iargs) const3411 IRet doApply (const EvalContext&, const IArgs& iargs) const
3412 {
3413 return IRet(iargs.a, iargs.b);
3414 }
3415 };
3416
3417 template <typename T>
3418 class GenVec<T, 3> : public PrimitiveFunc<Signature<Vector<T, 3>, T, T, T> >
3419 {
3420 public:
3421 typedef typename GenVec::IRet IRet;
3422 typedef typename GenVec::IArgs IArgs;
3423
getName(void) const3424 string getName (void) const
3425 {
3426 return vecName<T, 3>();
3427 }
3428
3429 protected:
doApply(const EvalContext &,const IArgs & iargs) const3430 IRet doApply (const EvalContext&, const IArgs& iargs) const
3431 {
3432 return IRet(iargs.a, iargs.b, iargs.c);
3433 }
3434 };
3435
3436 template <typename T>
3437 class GenVec<T, 4> : public PrimitiveFunc<Signature<Vector<T, 4>, T, T, T, T> >
3438 {
3439 public:
3440 typedef typename GenVec::IRet IRet;
3441 typedef typename GenVec::IArgs IArgs;
3442
getName(void) const3443 string getName (void) const { return vecName<T, 4>(); }
3444
3445 protected:
doApply(const EvalContext &,const IArgs & iargs) const3446 IRet doApply (const EvalContext&, const IArgs& iargs) const
3447 {
3448 return IRet(iargs.a, iargs.b, iargs.c, iargs.d);
3449 }
3450 };
3451
3452 template <typename T, int Rows, int Columns>
3453 class GenMat;
3454
3455 template <typename T, int Rows>
3456 class GenMat<T, Rows, 2> : public PrimitiveFunc<
3457 Signature<Matrix<T, Rows, 2>, Vector<T, Rows>, Vector<T, Rows> > >
3458 {
3459 public:
3460 typedef typename GenMat::Ret Ret;
3461 typedef typename GenMat::IRet IRet;
3462 typedef typename GenMat::IArgs IArgs;
3463
getName(void) const3464 string getName (void) const
3465 {
3466 return dataTypeNameOf<Ret>();
3467 }
3468
3469 protected:
3470
doApply(const EvalContext &,const IArgs & iargs) const3471 IRet doApply (const EvalContext&, const IArgs& iargs) const
3472 {
3473 IRet ret;
3474 ret[0] = iargs.a;
3475 ret[1] = iargs.b;
3476 return ret;
3477 }
3478 };
3479
3480 template <typename T, int Rows>
3481 class GenMat<T, Rows, 3> : public PrimitiveFunc<
3482 Signature<Matrix<T, Rows, 3>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows> > >
3483 {
3484 public:
3485 typedef typename GenMat::Ret Ret;
3486 typedef typename GenMat::IRet IRet;
3487 typedef typename GenMat::IArgs IArgs;
3488
getName(void) const3489 string getName (void) const
3490 {
3491 return dataTypeNameOf<Ret>();
3492 }
3493
3494 protected:
3495
doApply(const EvalContext &,const IArgs & iargs) const3496 IRet doApply (const EvalContext&, const IArgs& iargs) const
3497 {
3498 IRet ret;
3499 ret[0] = iargs.a;
3500 ret[1] = iargs.b;
3501 ret[2] = iargs.c;
3502 return ret;
3503 }
3504 };
3505
3506 template <typename T, int Rows>
3507 class GenMat<T, Rows, 4> : public PrimitiveFunc<
3508 Signature<Matrix<T, Rows, 4>,
3509 Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows> > >
3510 {
3511 public:
3512 typedef typename GenMat::Ret Ret;
3513 typedef typename GenMat::IRet IRet;
3514 typedef typename GenMat::IArgs IArgs;
3515
getName(void) const3516 string getName (void) const
3517 {
3518 return dataTypeNameOf<Ret>();
3519 }
3520
3521 protected:
doApply(const EvalContext &,const IArgs & iargs) const3522 IRet doApply (const EvalContext&, const IArgs& iargs) const
3523 {
3524 IRet ret;
3525 ret[0] = iargs.a;
3526 ret[1] = iargs.b;
3527 ret[2] = iargs.c;
3528 ret[3] = iargs.d;
3529 return ret;
3530 }
3531 };
3532
3533 template <typename T, int Rows>
mat2(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1)3534 ExprP<Matrix<T, Rows, 2> > mat2 (const ExprP<Vector<T, Rows> >& arg0,
3535 const ExprP<Vector<T, Rows> >& arg1)
3536 {
3537 return app<GenMat<T, Rows, 2> >(arg0, arg1);
3538 }
3539
3540 template <typename T, int Rows>
mat3(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2)3541 ExprP<Matrix<T, Rows, 3> > mat3 (const ExprP<Vector<T, Rows> >& arg0,
3542 const ExprP<Vector<T, Rows> >& arg1,
3543 const ExprP<Vector<T, Rows> >& arg2)
3544 {
3545 return app<GenMat<T, Rows, 3> >(arg0, arg1, arg2);
3546 }
3547
3548 template <typename T, int Rows>
mat4(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2,const ExprP<Vector<T,Rows>> & arg3)3549 ExprP<Matrix<T, Rows, 4> > mat4 (const ExprP<Vector<T, Rows> >& arg0,
3550 const ExprP<Vector<T, Rows> >& arg1,
3551 const ExprP<Vector<T, Rows> >& arg2,
3552 const ExprP<Vector<T, Rows> >& arg3)
3553 {
3554 return app<GenMat<T, Rows, 4> >(arg0, arg1, arg2, arg3);
3555 }
3556
3557 template <typename T, int Rows, int Cols>
3558 class MatNeg : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>,
3559 Matrix<T, Rows, Cols> > >
3560 {
3561 public:
3562 typedef typename MatNeg::IRet IRet;
3563 typedef typename MatNeg::IArgs IArgs;
3564
getName(void) const3565 string getName (void) const
3566 {
3567 return "_matNeg";
3568 }
3569
3570 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3571 void doPrint (ostream& os, const BaseArgExprs& args) const
3572 {
3573 os << "-(" << *args[0] << ")";
3574 }
3575
doApply(const EvalContext &,const IArgs & iargs) const3576 IRet doApply (const EvalContext&, const IArgs& iargs) const
3577 {
3578 IRet ret;
3579
3580 for (int col = 0; col < Cols; ++col)
3581 {
3582 for (int row = 0; row < Rows; ++row)
3583 ret[col][row] = -iargs.a[col][row];
3584 }
3585
3586 return ret;
3587 }
3588 };
3589
3590 template <typename T, typename Sig>
3591 class CompWiseFunc : public PrimitiveFunc<Sig>
3592 {
3593 public:
3594 typedef Func<Signature<T, T, T> > ScalarFunc;
3595
getName(void) const3596 string getName (void) const
3597 {
3598 return doGetScalarFunc().getName();
3599 }
3600 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3601 void doPrint (ostream& os,
3602 const BaseArgExprs& args) const
3603 {
3604 doGetScalarFunc().print(os, args);
3605 }
3606
3607 virtual
3608 const ScalarFunc& doGetScalarFunc (void) const = 0;
3609 };
3610
3611 template <typename T, int Rows, int Cols>
3612 class CompMatFuncBase : public CompWiseFunc<T, Signature<Matrix<T, Rows, Cols>,
3613 Matrix<T, Rows, Cols>,
3614 Matrix<T, Rows, Cols> > >
3615 {
3616 public:
3617 typedef typename CompMatFuncBase::IRet IRet;
3618 typedef typename CompMatFuncBase::IArgs IArgs;
3619
3620 protected:
3621
doApply(const EvalContext & ctx,const IArgs & iargs) const3622 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
3623 {
3624 IRet ret;
3625
3626 for (int col = 0; col < Cols; ++col)
3627 {
3628 for (int row = 0; row < Rows; ++row)
3629 ret[col][row] = this->doGetScalarFunc().apply(ctx,
3630 iargs.a[col][row],
3631 iargs.b[col][row]);
3632 }
3633
3634 return ret;
3635 }
3636 };
3637
3638 template <typename F, typename T, int Rows, int Cols>
3639 class CompMatFunc : public CompMatFuncBase<T, Rows, Cols>
3640 {
3641 protected:
doGetScalarFunc(void) const3642 const typename CompMatFunc::ScalarFunc& doGetScalarFunc (void) const
3643 {
3644 return instance<F>();
3645 }
3646 };
3647
3648 template <class T>
3649 class ScalarMatrixCompMult : public Mul< Signature<T, T, T> >
3650 {
3651 public:
3652
getName(void) const3653 string getName (void) const
3654 {
3655 return "matrixCompMult";
3656 }
3657
doPrint(ostream & os,const BaseArgExprs & args) const3658 void doPrint (ostream& os, const BaseArgExprs& args) const
3659 {
3660 Func<Signature<T, T, T> >::doPrint(os, args);
3661 }
3662 };
3663
//! Component-wise matrix multiplication over Rows x Cols matrices of T: applies
//! ScalarMatrixCompMult<T> to each pair of corresponding elements via CompMatFunc.
template <int Rows, int Cols, class T>
class MatrixCompMult : public CompMatFunc<ScalarMatrixCompMult<T>, T, Rows, Cols>
{
};
3668
3669 template <int Rows, int Cols>
3670 class ScalarMatFuncBase : public CompWiseFunc<float, Signature<Matrix<float, Rows, Cols>,
3671 Matrix<float, Rows, Cols>,
3672 float> >
3673 {
3674 public:
3675 typedef typename ScalarMatFuncBase::IRet IRet;
3676 typedef typename ScalarMatFuncBase::IArgs IArgs;
3677
3678 protected:
3679
doApply(const EvalContext & ctx,const IArgs & iargs) const3680 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
3681 {
3682 IRet ret;
3683
3684 for (int col = 0; col < Cols; ++col)
3685 {
3686 for (int row = 0; row < Rows; ++row)
3687 ret[col][row] = this->doGetScalarFunc().apply(ctx, iargs.a[col][row], iargs.b);
3688 }
3689
3690 return ret;
3691 }
3692 };
3693
3694 template <typename F, int Rows, int Cols>
3695 class ScalarMatFunc : public ScalarMatFuncBase<Rows, Cols>
3696 {
3697 protected:
doGetScalarFunc(void) const3698 const typename ScalarMatFunc::ScalarFunc& doGetScalarFunc (void) const
3699 {
3700 return instance<F>();
3701 }
3702 };
3703
3704 template<typename T, int Size> struct GenXType;
3705
3706 template<typename T>
3707 struct GenXType<T, 1>
3708 {
genXTypevkt::shaderexecutor::Functions::GenXType3709 static ExprP<T> genXType (const ExprP<T>& x) { return x; }
3710 };
3711
3712 template<typename T>
3713 struct GenXType<T, 2>
3714 {
genXTypevkt::shaderexecutor::Functions::GenXType3715 static ExprP<Vector<T, 2> > genXType (const ExprP<T>& x)
3716 {
3717 return app<GenVec<T, 2> >(x, x);
3718 }
3719 };
3720
3721 template<typename T>
3722 struct GenXType<T, 3>
3723 {
genXTypevkt::shaderexecutor::Functions::GenXType3724 static ExprP<Vector<T, 3> > genXType (const ExprP<T>& x)
3725 {
3726 return app<GenVec<T, 3> >(x, x, x);
3727 }
3728 };
3729
3730 template<typename T>
3731 struct GenXType<T, 4>
3732 {
genXTypevkt::shaderexecutor::Functions::GenXType3733 static ExprP<Vector<T, 4> > genXType (const ExprP<T>& x)
3734 {
3735 return app<GenVec<T, 4> >(x, x, x, x);
3736 }
3737 };
3738
3739 //! Returns an expression of vector of size `Size` (or scalar if Size == 1),
3740 //! with each element initialized with the expression `x`.
3741 template<typename T, int Size>
genXType(const ExprP<T> & x)3742 ExprP<typename ContainerOf<T, Size>::Container> genXType (const ExprP<T>& x)
3743 {
3744 return GenXType<T, Size>::genXType(x);
3745 }
3746
// Vector constructor functions vec2/vec3/vec4 for float, deFloat16 and double
// element types. Each DEFINE_CONSTRUCTORn invocation is given the GenVec typedef,
// the result type, the function name and the argument types; presumably it defines
// an overload of that function that applies the GenVec (see the macro definition).

typedef GenVec<float, 2> FloatVec2;
DEFINE_CONSTRUCTOR2(FloatVec2, Vec2, vec2, float, float)

typedef GenVec<deFloat16, 2> FloatVec2_16bit;
DEFINE_CONSTRUCTOR2(FloatVec2_16bit, Vec2_16Bit, vec2, deFloat16, deFloat16)

typedef GenVec<double, 2> DoubleVec2;
DEFINE_CONSTRUCTOR2(DoubleVec2, Vec2_64Bit, vec2, double, double)

typedef GenVec<float, 3> FloatVec3;
DEFINE_CONSTRUCTOR3(FloatVec3, Vec3, vec3, float, float, float)

typedef GenVec<deFloat16, 3> FloatVec3_16bit;
DEFINE_CONSTRUCTOR3(FloatVec3_16bit, Vec3_16Bit, vec3, deFloat16, deFloat16, deFloat16)

typedef GenVec<double, 3> DoubleVec3;
DEFINE_CONSTRUCTOR3(DoubleVec3, Vec3_64Bit, vec3, double, double, double)

typedef GenVec<float, 4> FloatVec4;
DEFINE_CONSTRUCTOR4(FloatVec4, Vec4, vec4, float, float, float, float)

typedef GenVec<deFloat16, 4> FloatVec4_16bit;
DEFINE_CONSTRUCTOR4(FloatVec4_16bit, Vec4_16Bit, vec4, deFloat16, deFloat16, deFloat16, deFloat16)

typedef GenVec<double, 4> DoubleVec4;
DEFINE_CONSTRUCTOR4(DoubleVec4, Vec4_64Bit, vec4, double, double, double, double)
3773
3774 template <class T>
3775 const ExprP<T> getConstZero(void);
3776 template <class T>
3777 const ExprP<T> getConstOne(void);
3778 template <class T>
3779 const ExprP<T> getConstTwo(void);
3780
3781 template <>
getConstZero(void)3782 const ExprP<float> getConstZero<float>(void)
3783 {
3784 return constant(0.0f);
3785 }
3786
3787 template <>
getConstZero(void)3788 const ExprP<deFloat16> getConstZero<deFloat16>(void)
3789 {
3790 return constant((deFloat16)FLOAT16_0_0);
3791 }
3792
3793 template <>
getConstZero(void)3794 const ExprP<double> getConstZero<double>(void)
3795 {
3796 return constant(0.0);
3797 }
3798
3799 template <>
getConstOne(void)3800 const ExprP<float> getConstOne<float>(void)
3801 {
3802 return constant(1.0f);
3803 }
3804
3805 template <>
getConstOne(void)3806 const ExprP<deFloat16> getConstOne<deFloat16>(void)
3807 {
3808 return constant((deFloat16)FLOAT16_1_0);
3809 }
3810
3811 template <>
getConstOne(void)3812 const ExprP<double> getConstOne<double>(void)
3813 {
3814 return constant(1.0);
3815 }
3816
3817 template <>
getConstTwo(void)3818 const ExprP<float> getConstTwo<float>(void)
3819 {
3820 return constant(2.0f);
3821 }
3822
3823 template <>
getConstTwo(void)3824 const ExprP<deFloat16> getConstTwo<deFloat16>(void)
3825 {
3826 return constant((deFloat16)FLOAT16_2_0);
3827 }
3828
3829 template <>
getConstTwo(void)3830 const ExprP<double> getConstTwo<double>(void)
3831 {
3832 return constant(2.0);
3833 }
3834
//! dot(a, b) for Size-component vectors. The expansion is an alternatives() tree
//! over several summation orders of the per-component products, so a result from
//! any of those orders is accepted.
template <int Size, class T>
class Dot : public DerivedFunc<Signature<T, Vector<T, Size>, Vector<T, Size> > >
{
public:
	typedef typename Dot::ArgExprs ArgExprs;

	string			getName		(void) const
	{
		return "dot";
	}

protected:
	ExprP<T>		doExpand	(ExpandContext&, const ArgExprs& args) const
	{
		ExprP<T> op[Size];
		// Precompute all products.
		for (int ndx = 0; ndx < Size; ++ndx)
			op[ndx] = args.a[ndx] * args.b[ndx];

		int idx[Size];
		// Prepare an array of indices. It starts sorted (the identity permutation),
		// which std::next_permutation needs in order to visit every ordering.
		for (int ndx = 0; ndx < Size; ++ndx)
			idx[ndx] = ndx;

		ExprP<T> res = op[0];
		// Compute the first dot alternative: SUM(a[i]*b[i]), i = 0 .. Size-1
		for (int ndx = 1; ndx < Size; ++ndx)
			res = res + op[ndx];

		// Generate all permutations of indices and, for each permutation, compute a
		// dot alternative: a left-to-right sum (seeded with zero) of the products in
		// that order. This covers every operand order of the summation of products in
		// the dot product expansion expression (each summed left to right).
		do {
			ExprP<T> alt = getConstZero<T>();
			for (int ndx = 0; ndx < Size; ++ndx)
				alt = alt + op[idx[ndx]];
			res = alternatives(res, alt);
		} while (std::next_permutation(idx, idx + Size));

		return res;
	}
};
3877
3878 template <class T>
3879 class Dot<1, T> : public DerivedFunc<Signature<T, T, T> >
3880 {
3881 public:
3882 typedef typename DerivedFunc<Signature<T, T, T> >::ArgExprs TArgExprs;
3883
getName(void) const3884 string getName (void) const
3885 {
3886 return "dot";
3887 }
3888
doExpand(ExpandContext &,const TArgExprs & args) const3889 ExprP<T> doExpand (ExpandContext&, const TArgExprs& args) const
3890 {
3891 return args.a * args.b;
3892 }
3893 };
3894
3895 template <int Size>
dot(const ExprP<Vector<deFloat16,Size>> & x,const ExprP<Vector<deFloat16,Size>> & y)3896 ExprP<deFloat16> dot (const ExprP<Vector<deFloat16, Size> >& x, const ExprP<Vector<deFloat16, Size> >& y)
3897 {
3898 return app<Dot<Size, deFloat16> >(x, y);
3899 }
3900
dot(const ExprP<deFloat16> & x,const ExprP<deFloat16> & y)3901 ExprP<deFloat16> dot (const ExprP<deFloat16>& x, const ExprP<deFloat16>& y)
3902 {
3903 return app<Dot<1, deFloat16> >(x, y);
3904 }
3905
3906 template <int Size>
dot(const ExprP<Vector<float,Size>> & x,const ExprP<Vector<float,Size>> & y)3907 ExprP<float> dot (const ExprP<Vector<float, Size> >& x, const ExprP<Vector<float, Size> >& y)
3908 {
3909 return app<Dot<Size, float> >(x, y);
3910 }
3911
dot(const ExprP<float> & x,const ExprP<float> & y)3912 ExprP<float> dot (const ExprP<float>& x, const ExprP<float>& y)
3913 {
3914 return app<Dot<1, float> >(x, y);
3915 }
3916
3917 template <int Size>
dot(const ExprP<Vector<double,Size>> & x,const ExprP<Vector<double,Size>> & y)3918 ExprP<double> dot (const ExprP<Vector<double, Size> >& x, const ExprP<Vector<double, Size> >& y)
3919 {
3920 return app<Dot<Size, double> >(x, y);
3921 }
3922
dot(const ExprP<double> & x,const ExprP<double> & y)3923 ExprP<double> dot (const ExprP<double>& x, const ExprP<double>& y)
3924 {
3925 return app<Dot<1, double> >(x, y);
3926 }
3927
3928 template <int Size, class T>
3929 class Length : public DerivedFunc<
3930 Signature<T, typename ContainerOf<T, Size>::Container> >
3931 {
3932 public:
3933 typedef typename Length::ArgExprs ArgExprs;
3934
getName(void) const3935 string getName (void) const
3936 {
3937 return "length";
3938 }
3939
3940 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3941 ExprP<T> doExpand (ExpandContext&, const ArgExprs& args) const
3942 {
3943 return sqrt(dot(args.a, args.a));
3944 }
3945 };
3946
3947
3948 template <class T, class TRet>
length(const ExprP<T> & x)3949 ExprP<TRet> length (const ExprP<T>& x)
3950 {
3951 return app<Length<1, T> >(x);
3952 }
3953
3954 template <int Size, class T, class TRet>
length(const ExprP<typename ContainerOf<T,Size>::Container> & x)3955 ExprP<TRet> length (const ExprP<typename ContainerOf<T, Size>::Container>& x)
3956 {
3957 return app<Length<Size, T> >(x);
3958 }
3959
3960 template <int Size, class T>
3961 class Distance : public DerivedFunc<
3962 Signature<T,
3963 typename ContainerOf<T, Size>::Container,
3964 typename ContainerOf<T, Size>::Container> >
3965 {
3966 public:
3967 typedef typename Distance::Ret Ret;
3968 typedef typename Distance::ArgExprs ArgExprs;
3969
getName(void) const3970 string getName (void) const
3971 {
3972 return "distance";
3973 }
3974
3975 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3976 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
3977 {
3978 return length<Size, T, Ret>(args.a - args.b);
3979 }
3980 };
3981
3982 // cross
3983
3984 class Cross : public DerivedFunc<Signature<Vec3, Vec3, Vec3> >
3985 {
3986 public:
getName(void) const3987 string getName (void) const
3988 {
3989 return "cross";
3990 }
3991
3992 protected:
doExpand(ExpandContext &,const ArgExprs & x) const3993 ExprP<Vec3> doExpand (ExpandContext&, const ArgExprs& x) const
3994 {
3995 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
3996 x.a[2] * x.b[0] - x.b[2] * x.a[0],
3997 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
3998 }
3999 };
4000
4001 class Cross16Bit : public DerivedFunc<Signature<Vec3_16Bit, Vec3_16Bit, Vec3_16Bit> >
4002 {
4003 public:
getName(void) const4004 string getName (void) const
4005 {
4006 return "cross";
4007 }
4008
4009 protected:
doExpand(ExpandContext &,const ArgExprs & x) const4010 ExprP<Vec3_16Bit> doExpand (ExpandContext&, const ArgExprs& x) const
4011 {
4012 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
4013 x.a[2] * x.b[0] - x.b[2] * x.a[0],
4014 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4015 }
4016 };
4017
4018 class Cross64Bit : public DerivedFunc<Signature<Vec3_64Bit, Vec3_64Bit, Vec3_64Bit> >
4019 {
4020 public:
getName(void) const4021 string getName (void) const
4022 {
4023 return "cross";
4024 }
4025
4026 protected:
doExpand(ExpandContext &,const ArgExprs & x) const4027 ExprP<Vec3_64Bit> doExpand (ExpandContext&, const ArgExprs& x) const
4028 {
4029 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2],
4030 x.a[2] * x.b[0] - x.b[2] * x.a[0],
4031 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4032 }
4033 };
4034
// cross() constructor functions for each vector width, presumably each an overload
// that applies the corresponding Cross* function class (see DEFINE_CONSTRUCTOR2).
DEFINE_CONSTRUCTOR2(Cross, Vec3, cross, Vec3, Vec3)
DEFINE_CONSTRUCTOR2(Cross16Bit, Vec3_16Bit, cross, Vec3_16Bit, Vec3_16Bit)
DEFINE_CONSTRUCTOR2(Cross64Bit, Vec3_64Bit, cross, Vec3_64Bit, Vec3_64Bit)
4038
4039 template<int Size, class T>
4040 class Normalize : public DerivedFunc<
4041 Signature<typename ContainerOf<T, Size>::Container,
4042 typename ContainerOf<T, Size>::Container> >
4043 {
4044 public:
4045 typedef typename Normalize::Ret Ret;
4046 typedef typename Normalize::ArgExprs ArgExprs;
4047
getName(void) const4048 string getName (void) const
4049 {
4050 return "normalize";
4051 }
4052
4053 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4054 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4055 {
4056 return args.a * app<InverseSqrt<Signature<T, T>>>(dot(args.a, args.a));
4057 }
4058 };
4059
4060 template <int Size, class T>
4061 class FaceForward : public DerivedFunc<
4062 Signature<typename ContainerOf<T, Size>::Container,
4063 typename ContainerOf<T, Size>::Container,
4064 typename ContainerOf<T, Size>::Container,
4065 typename ContainerOf<T, Size>::Container> >
4066 {
4067 public:
4068 typedef typename FaceForward::Ret Ret;
4069 typedef typename FaceForward::ArgExprs ArgExprs;
4070
getName(void) const4071 string getName (void) const
4072 {
4073 return "faceforward";
4074 }
4075
4076 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4077 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4078 {
4079 return cond(dot(args.c, args.b) < getConstZero<T>(), args.a, -args.a);
4080 }
4081 };
4082
//! reflect(I, N) = I - 2 * dot(N, I) * N. Several association orders of the
//! scaling by 2 are offered as alternatives().
template <int Size, class T>
class Reflect : public DerivedFunc<
	Signature<typename ContainerOf<T, Size>::Container,
			  typename ContainerOf<T, Size>::Container,
			  typename ContainerOf<T, Size>::Container> >
{
public:
	typedef typename Reflect::Ret		Ret;
	typedef typename Reflect::Arg0		Arg0;
	typedef typename Reflect::Arg1		Arg1;
	typedef typename Reflect::ArgExprs	ArgExprs;

	string		getName		(void) const
	{
		return "reflect";
	}

protected:
	ExprP<Ret>	doExpand	(ExpandContext& ctx, const ArgExprs& args) const
	{
		const ExprP<Arg0>&	i		= args.a;
		const ExprP<Arg1>&	n		= args.b;
		// dot(N, I) is bound to a named variable so the expansion reuses one value.
		const ExprP<T>		dotNI	= bindExpression("dotNI", ctx, dot(n, i));

		// Accepted forms: (N*d)*2, N*(d*2), N*dot(2I, N), N*dot(I, 2N).
		return i - alternatives((n * dotNI) * getConstTwo<T>(),
					alternatives( n * (dotNI * getConstTwo<T>()),
						alternatives(n * dot(i * getConstTwo<T>(), n),
							n * dot(i, n * getConstTwo<T>())
						)
					)
				);
	}
};
4116
//! Scalar reflect(I, N) = I - 2 * N * I * N. Like the vector version, but it also
//! accepts dot(N*N, 2I) as an alternative.
template <class T>
class Reflect<1, T> : public DerivedFunc<
	Signature<T, T, T> >
{
public:
	typedef typename Reflect::Ret		Ret;
	typedef typename Reflect::Arg0		Arg0;
	typedef typename Reflect::Arg1		Arg1;
	typedef typename Reflect::ArgExprs	ArgExprs;

	string		getName		(void) const
	{
		return "reflect";
	}

protected:
	ExprP<Ret>	doExpand	(ExpandContext& ctx, const ArgExprs& args) const
	{
		const ExprP<Arg0>&	i		= args.a;
		const ExprP<Arg1>&	n		= args.b;
		// dot(N, I) is bound to a named variable so the expansion reuses one value.
		const ExprP<T>		dotNI	= bindExpression("dotNI", ctx, dot(n, i));

		// Accepted forms: (N*d)*2, N*(d*2), N*dot(2I, N), N*dot(I, 2N), dot(N*N, 2I).
		return i - alternatives((n * dotNI) * getConstTwo<T>(),
					alternatives( n * (dotNI * getConstTwo<T>()),
						alternatives(n * dot(i * getConstTwo<T>(), n),
							alternatives(n * dot(i, n * getConstTwo<T>()),
								dot(n * n, i * getConstTwo<T>()))
						)
					)
				);
	}
};
4149
//! refract(I, N, eta): with k = 1 - eta^2 * (1 - dot(N, I)^2), returns a zero vector
//! when k < 0, otherwise eta * I - (eta * dot(N, I) + sqrt(k)) * N.
template <int Size, class T>
class Refract : public DerivedFunc<
	Signature<typename ContainerOf<T, Size>::Container,
			  typename ContainerOf<T, Size>::Container,
			  typename ContainerOf<T, Size>::Container,
			  T> >
{
public:
	typedef typename Refract::Ret		Ret;
	typedef typename Refract::Arg0		Arg0;
	typedef typename Refract::Arg1		Arg1;
	typedef typename Refract::Arg2		Arg2;
	typedef typename Refract::ArgExprs	ArgExprs;

	string		getName		(void) const
	{
		return "refract";
	}

protected:
	ExprP<Ret>	doExpand	(ExpandContext& ctx, const ArgExprs& args) const
	{
		const ExprP<Arg0>&	i		= args.a;
		const ExprP<Arg1>&	n		= args.b;
		const ExprP<Arg2>&	eta		= args.c;
		// Both intermediates are bound to named variables, dotNI first because k uses it.
		const ExprP<T>		dotNI	= bindExpression("dotNI", ctx, dot(n, i));
		const ExprP<T>		k		= bindExpression("k", ctx, getConstOne<T>() - eta * eta *
												 (getConstOne<T>() - dotNI * dotNI));
		// Total internal reflection (k < 0) yields a zero vector of the argument width.
		return cond(k < getConstZero<T>(),
					genXType<T, Size>(getConstZero<T>()),
					i * eta - n * (eta * dotNI + sqrt(k)));
	}
};
4183
4184 template <class T>
4185 class PreciseFunc1 : public CFloatFunc1<T>
4186 {
4187 public:
PreciseFunc1(const string & name,DoubleFunc1 & func)4188 PreciseFunc1 (const string& name, DoubleFunc1& func) : CFloatFunc1<T> (name, func) {}
4189 protected:
precision(const EvalContext &,double,double) const4190 double precision (const EvalContext&, double, double) const { return 0.0; }
4191 };
4192
4193 template <class T>
4194 class Abs : public PreciseFunc1<T>
4195 {
4196 public:
Abs(void)4197 Abs (void) : PreciseFunc1<T> ("abs", deAbs) {}
4198 };
4199
4200 template <class T>
4201 class Sign : public PreciseFunc1<T>
4202 {
4203 public:
Sign(void)4204 Sign (void) : PreciseFunc1<T> ("sign", deSign) {}
4205 };
4206
4207 template <class T>
4208 class Floor : public PreciseFunc1<T>
4209 {
4210 public:
Floor(void)4211 Floor (void) : PreciseFunc1<T> ("floor", deFloor) {}
4212 };
4213
4214 template <class T>
4215 class Trunc : public PreciseFunc1<T>
4216 {
4217 public:
Trunc(void)4218 Trunc (void) : PreciseFunc1<T> ("trunc", deTrunc) {}
4219 };
4220
4221 template <class T>
4222 class Round : public FloatFunc1<T>
4223 {
4224 public:
getName(void) const4225 string getName (void) const { return "round"; }
4226
4227 protected:
applyPoint(const EvalContext &,double x) const4228 Interval applyPoint (const EvalContext&, double x) const
4229 {
4230 double truncated = 0.0;
4231 const double fract = deModf(x, &truncated);
4232 Interval ret;
4233
4234 if (fabs(fract) <= 0.5)
4235 ret |= truncated;
4236 if (fabs(fract) >= 0.5)
4237 ret |= truncated + deSign(fract);
4238
4239 return ret;
4240 }
4241
precision(const EvalContext &,double,double) const4242 double precision (const EvalContext&, double, double) const { return 0.0; }
4243 };
4244
4245 template <class T>
4246 class RoundEven : public PreciseFunc1<T>
4247 {
4248 public:
RoundEven(void)4249 RoundEven (void) : PreciseFunc1<T> ("roundEven", deRoundEven) {}
4250 };
4251
4252 template <class T>
4253 class Ceil : public PreciseFunc1<T>
4254 {
4255 public:
Ceil(void)4256 Ceil (void) : PreciseFunc1<T> ("ceil", deCeil) {}
4257 };
4258
4259 typedef Floor< Signature<float, float> > Floor32Bit;
4260 typedef Floor< Signature<deFloat16, deFloat16> > Floor16Bit;
4261 typedef Floor< Signature<double, double> > Floor64Bit;
4262
4263 typedef Trunc< Signature<float, float> > Trunc32Bit;
4264 typedef Trunc< Signature<deFloat16, deFloat16> > Trunc16Bit;
4265 typedef Trunc< Signature<double, double> > Trunc64Bit;
4266
4267 typedef Trunc< Signature<float, float> > Trunc32Bit;
4268 typedef Trunc< Signature<deFloat16, deFloat16> > Trunc16Bit;
4269
// fract(x) = x - floor(x) for 32-, 16- and 64-bit floats. The DEFINE_DERIVED_* macros
// presumably define a derived function that expands into the given expression.
DEFINE_DERIVED_FLOAT1(Fract, fract, x, x - app<Floor32Bit>(x))
DEFINE_DERIVED_FLOAT1_16BIT(Fract16Bit, fract, x, x - app<Floor16Bit>(x))
DEFINE_DERIVED_DOUBLE1(Fract64Bit, fract, x, x - app<Floor64Bit>(x))
4273
4274 template <class T>
4275 class PreciseFunc2 : public CFloatFunc2<T>
4276 {
4277 public:
PreciseFunc2(const string & name,DoubleFunc2 & func)4278 PreciseFunc2 (const string& name, DoubleFunc2& func) : CFloatFunc2<T> (name, func) {}
4279 protected:
precision(const EvalContext &,double,double,double) const4280 double precision (const EvalContext&, double, double, double) const { return 0.0; }
4281 };
4282
// mod(x, y) = x - y * floor(x / y) for 32-, 16- and 64-bit floats.
DEFINE_DERIVED_FLOAT2(Mod32Bit, mod, x, y, x - y * app<Floor32Bit>(x / y))
DEFINE_DERIVED_FLOAT2_16BIT(Mod16Bit, mod, x, y, x - y * app<Floor16Bit>(x / y))
DEFINE_DERIVED_DOUBLE2(Mod64Bit, mod, x, y, x - y * app<Floor64Bit>(x / y))

// frem(x, y) = x - y * trunc(x / y). The CASED variants additionally carry the
// SPIRV_CASETYPE_FREM case type (presumably selecting the SPIR-V OpFRem test path).
DEFINE_CASED_DERIVED_FLOAT2(FRem32Bit, frem, x, y, x - y * app<Trunc32Bit>(x / y), SPIRV_CASETYPE_FREM)
DEFINE_CASED_DERIVED_FLOAT2_16BIT(FRem16Bit, frem, x, y, x - y * app<Trunc16Bit>(x / y), SPIRV_CASETYPE_FREM)
DEFINE_CASED_DERIVED_DOUBLE2(FRem64Bit, frem, x, y, x - y * app<Trunc64Bit>(x / y), SPIRV_CASETYPE_FREM)
4290
//! modf(x, out whole): returns the fractional part of x and writes the integral
//! part into the out parameter, which is argument index 1 (see getOutParamIndex()).
template <class T>
class Modf : public PrimitiveFunc<T>
{
public:
	typedef typename Modf<T>::IArgs	TIArgs;
	typedef typename Modf<T>::IRet	TIRet;
	string	getName				(void) const
	{
		return "modf";
	}

protected:
	TIRet	doApply				(const EvalContext& ctx, const TIArgs& iargs) const
	{
		Interval	fracIV;
		// The out parameter arrives in the (const) argument tuple, so it is written
		// through a const_cast.
		Interval&	wholeIV		= const_cast<Interval&>(iargs.b);
		double		intPart		= 0;

		// NOTE(review): TCU_INTERVAL_APPLY_MONOTONE1 presumably evaluates only at the
		// interval endpoints. The fractional part is not monotone across integer
		// boundaries, so this appears to assume input intervals do not straddle an
		// integer. Confirm against the macro definition.
		TCU_INTERVAL_APPLY_MONOTONE1(fracIV, x, iargs.a, frac, frac = deModf(x, &intPart));
		TCU_INTERVAL_APPLY_MONOTONE1(wholeIV, x, iargs.a, whole,
									 deModf(x, &intPart); whole = intPart);

		if (!iargs.a.isFinite(ctx.format.getMaxValue()))
		{
			// Behavior on modf(Inf) not well-defined, allow anything as a fractional part
			// See Khronos bug 13907
			fracIV |= TCU_NAN;
		}

		return fracIV;
	}

	int		getOutParamIndex	(void) const
	{
		return 1;
	}
};
4328 typedef Modf< Signature<float, float, float> > Modf32Bit;
4329 typedef Modf< Signature<deFloat16, deFloat16, deFloat16> > Modf16Bit;
4330 typedef Modf< Signature<double, double, double> > Modf64Bit;
4331
4332 template <class T>
4333 class ModfStruct : public Modf<T>
4334 {
4335 public:
getName(void) const4336 virtual string getName (void) const { return "modfstruct"; }
getSpirvCase(void) const4337 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_MODFSTRUCT; }
4338 };
4339 typedef ModfStruct< Signature<float, float, float> > ModfStruct32Bit;
4340 typedef ModfStruct< Signature<deFloat16, deFloat16, deFloat16> > ModfStruct16Bit;
4341 typedef ModfStruct< Signature<double, double, double> > ModfStruct64Bit;
4342
4343 template <class T>
Min(void)4344 class Min : public PreciseFunc2<T> { public: Min (void) : PreciseFunc2<T> ("min", deMin) {} };
4345 template <class T>
Max(void)4346 class Max : public PreciseFunc2<T> { public: Max (void) : PreciseFunc2<T> ("max", deMax) {} };
4347
4348 template <class T>
4349 class Clamp : public FloatFunc3<T>
4350 {
4351 public:
getName(void) const4352 string getName (void) const { return "clamp"; }
4353
applyExact(double x,double minVal,double maxVal) const4354 double applyExact (double x, double minVal, double maxVal) const
4355 {
4356 return de::min(de::max(x, minVal), maxVal);
4357 }
4358
precision(const EvalContext &,double,double,double minVal,double maxVal) const4359 double precision (const EvalContext&, double, double, double minVal, double maxVal) const
4360 {
4361 return minVal > maxVal ? TCU_NAN : 0.0;
4362 }
4363 };
4364
//! clamp(x, minVal, maxVal) expression builders, one per floating-point width.
ExprP<deFloat16> clamp(const ExprP<deFloat16>& x, const ExprP<deFloat16>& minVal, const ExprP<deFloat16>& maxVal)
{
	typedef Clamp< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > ClampFunc;
	return app<ClampFunc>(x, minVal, maxVal);
}

ExprP<float> clamp(const ExprP<float>& x, const ExprP<float>& minVal, const ExprP<float>& maxVal)
{
	typedef Clamp< Signature<float, float, float, float> > ClampFunc;
	return app<ClampFunc>(x, minVal, maxVal);
}

ExprP<double> clamp(const ExprP<double>& x, const ExprP<double>& minVal, const ExprP<double>& maxVal)
{
	typedef Clamp< Signature<double, double, double, double> > ClampFunc;
	return app<ClampFunc>(x, minVal, maxVal);
}
4379
4380 template <class T>
4381 class NanIfGreaterOrEqual : public FloatFunc2<T>
4382 {
4383 public:
getName(void) const4384 string getName (void) const { return "nanIfGreaterOrEqual"; }
4385
applyExact(double edge0,double edge1) const4386 double applyExact (double edge0, double edge1) const
4387 {
4388 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4389 }
4390
precision(const EvalContext &,double,double edge0,double edge1) const4391 double precision (const EvalContext&, double, double edge0, double edge1) const
4392 {
4393 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4394 }
4395 };
4396
//! nanIfGreaterOrEqual(edge0, edge1) expression builders, one per floating-point width.
ExprP<deFloat16> nanIfGreaterOrEqual(const ExprP<deFloat16>& edge0, const ExprP<deFloat16>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<deFloat16, deFloat16, deFloat16> > NanFunc;
	return app<NanFunc>(edge0, edge1);
}

ExprP<float> nanIfGreaterOrEqual(const ExprP<float>& edge0, const ExprP<float>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<float, float, float> > NanFunc;
	return app<NanFunc>(edge0, edge1);
}

ExprP<double> nanIfGreaterOrEqual(const ExprP<double>& edge0, const ExprP<double>& edge1)
{
	typedef NanIfGreaterOrEqual< Signature<double, double, double> > NanFunc;
	return app<NanFunc>(edge0, edge1);
}
4411
// mix(x, y, a) for 32-, 16- and 64-bit floats. The expansion is the alternatives() of
// the two common forms, x * (1 - a) + y * a and x + (y - x) * a.
DEFINE_DERIVED_FLOAT3(Mix, mix, x, y, a, alternatives((x * (constant(1.0f) - a)) + y * a,
													  x + (y - x) * a))

DEFINE_DERIVED_FLOAT3_16BIT(Mix16Bit, mix, x, y, a, alternatives((x * (constant((deFloat16)FLOAT16_1_0) - a)) + y * a,
																  x + (y - x) * a))

DEFINE_DERIVED_DOUBLE3(Mix64Bit, mix, x, y, a, alternatives((x * (constant(1.0) - a)) + y * a,
															 x + (y - x) * a))
4420
//! GLSL step(): 0.0 when x is strictly below edge, otherwise 1.0. A NaN operand makes
//! the comparison false, so NaN maps to 1.0.
static double step (double edge, double x)
{
	if (x < edge)
		return 0.0;

	return 1.0;
}
4425
4426 template <class T>
Step(void)4427 class Step : public PreciseFunc2<T> { public: Step (void) : PreciseFunc2<T> ("step", step) {} };
4428
4429 template <class T>
4430 class SmoothStep : public DerivedFunc<T>
4431 {
4432 public:
4433 typedef typename SmoothStep<T>::ArgExprs TArgExprs;
4434 typedef typename SmoothStep<T>::Ret TRet;
getName(void) const4435 string getName (void) const
4436 {
4437 return "smoothstep";
4438 }
4439
4440 protected:
4441
4442 ExprP<TRet> doExpand (ExpandContext& ctx, const TArgExprs& args) const;
4443 };
4444
// smoothstep for 32-bit floats:
//   t      = clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0)
//   result = t * t * (3.0 - 2.0 * t)
// nanIfGreaterOrEqual(edge0, edge1) is added to t so that edge0 >= edge1
// yields NaN (and hence a non-analyzable result).
template<>
ExprP<SmoothStep< Signature<float, float, float, float> >::Ret> SmoothStep< Signature<float, float, float, float> >::doExpand (ExpandContext& ctx, const SmoothStep< Signature<float, float, float, float> >::ArgExprs& args) const
{
	const ExprP<float>&		edge0	= args.a;
	const ExprP<float>&		edge1	= args.b;
	const ExprP<float>&		x		= args.c;
	const ExprP<float>		tExpr	= clamp((x - edge0) / (edge1 - edge0), constant(0.0f), constant(1.0f))
									+ nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
	// t is bound as the named expression "t" in ctx and reused three times below.
	const ExprP<float>		t		= bindExpression("t", ctx, tExpr);

	return (t * t * (constant(3.0f) - constant(2.0f) * t));
}
4457
// smoothstep for 16-bit floats; same formula as the 32-bit specialization, with
// constants written via the FLOAT16_* half-float constants.
// nanIfGreaterOrEqual(edge0, edge1) forces NaN when edge0 >= edge1.
template<>
ExprP<SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> >::TRet> SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> >::doExpand (ExpandContext& ctx, const TArgExprs& args) const
{
	const ExprP<deFloat16>&		edge0	= args.a;
	const ExprP<deFloat16>&		edge1	= args.b;
	const ExprP<deFloat16>&		x		= args.c;
	const ExprP<deFloat16>		tExpr	= clamp(( x - edge0 ) / ( edge1 - edge0 ),
											constant((deFloat16)FLOAT16_0_0), constant((deFloat16)FLOAT16_1_0))
										+ nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
	// t is bound as the named expression "t" in ctx and reused three times below.
	const ExprP<deFloat16>		t		= bindExpression("t", ctx, tExpr);

	return (t * t * (constant((deFloat16)FLOAT16_3_0) - constant((deFloat16)FLOAT16_2_0) * t));
}
4471
// smoothstep for 64-bit floats; same formula as the 32-bit specialization.
// nanIfGreaterOrEqual(edge0, edge1) forces NaN when edge0 >= edge1.
template<>
ExprP<SmoothStep< Signature<double, double, double, double> >::Ret> SmoothStep< Signature<double, double, double, double> >::doExpand (ExpandContext& ctx, const SmoothStep< Signature<double, double, double, double> >::ArgExprs& args) const
{
	const ExprP<double>&	edge0	= args.a;
	const ExprP<double>&	edge1	= args.b;
	const ExprP<double>&	x		= args.c;
	const ExprP<double>		tExpr	= clamp((x - edge0) / (edge1 - edge0), constant(0.0), constant(1.0))
									+ nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
	// t is bound as the named expression "t" in ctx and reused three times below.
	const ExprP<double>		t		= bindExpression("t", ctx, tExpr);

	return (t * t * (constant(3.0) - constant(2.0) * t));
}
4484
//Signature<float, float, int>
//Signature<deFloat16, deFloat16, int>
//Signature<double, double, int>
// frexp(x, out exp): splits x into a fraction and a power-of-two exponent.
// Interval evaluation returns bounds for the fraction and writes bounds for
// the exponent into the out parameter (argument index 1, see getOutParamIndex).
template <class T>
class FrExp : public PrimitiveFunc<T>
{
public:
	string	getName	(void) const
	{
		return "frexp";
	}

	typedef typename FrExp::IRet	IRet;
	typedef typename FrExp::IArgs	IArgs;
	typedef typename FrExp::IArg0	IArg0;
	typedef typename FrExp::IArg1	IArg1;

protected:
	IRet	doApply	(const EvalContext&, const IArgs& iargs) const
	{
		IRet			ret;
		const IArg0&	x			= iargs.a;
		// The exponent is an out parameter, so its interval is written back
		// through the (const) argument pack.
		IArg1&			exponent	= const_cast<IArg1&>(iargs.b);

		if (x.hasNaN() || x.contains(TCU_INFINITY) || x.contains(-TCU_INFINITY))
		{
			// GLSL (in contrast to IEEE) says that result of applying frexp
			// to infinity is undefined
			// Fraction: anything (including NaN); exponent: the full 32-bit signed range.
			ret = Interval::unbounded() | TCU_NAN;
			exponent = Interval(-deLdExp(1.0, 31), deLdExp(1.0, 31)-1);
		}
		else if (!x.empty())
		{
			// Decompose both endpoints of x.
			int				loExp	= 0;
			const double	loFrac	= deFrExp(x.lo(), &loExp);
			int				hiExp	= 0;
			const double	hiFrac	= deFrExp(x.hi(), &hiExp);

			if (deSign(loFrac) != deSign(hiFrac))
			{
				// Endpoint fractions differ in sign (presumably x spans zero): values
				// near zero have arbitrarily small exponents, so only the upper exponent
				// bound is known. 1.0 - DBL_EPSILON*0.5 is the largest double below 1.0.
				exponent = Interval(-TCU_INFINITY, de::max(loExp, hiExp));
				ret = Interval();
				if (deSign(loFrac) < 0)
					ret |= Interval(-1.0 + DBL_EPSILON*0.5, 0.0);
				if (deSign(hiFrac) > 0)
					ret |= Interval(0.0, 1.0 - DBL_EPSILON*0.5);
			}
			else
			{
				// Same sign: the exponent lies between the endpoint exponents. With one
				// common exponent the fraction range follows the endpoints directly;
				// otherwise its magnitude can be anywhere in [0.5, 1).
				exponent = Interval(loExp, hiExp);
				if (loExp == hiExp)
					ret = Interval(loFrac, hiFrac);
				else
					ret = deSign(loFrac) * Interval(0.5, 1.0 - DBL_EPSILON*0.5);
			}
		}
		// NOTE(review): for an empty x, ret keeps its default-constructed value
		// and exponent is left untouched.

		return ret;
	}

	int		getOutParamIndex	(void) const
	{
		return 1;
	}
};
typedef FrExp< Signature<float, float, int> >			Frexp32Bit;
typedef FrExp< Signature<deFloat16, deFloat16, int> >	Frexp16Bit;
typedef FrExp< Signature<double, double, int> >			Frexp64Bit;
4553
4554 template <class T>
4555 class FrexpStruct : public FrExp<T>
4556 {
4557 public:
getName(void) const4558 virtual string getName (void) const { return "frexpstruct"; }
getSpirvCase(void) const4559 virtual SpirVCaseT getSpirvCase (void) const { return SPIRV_CASETYPE_FREXPSTRUCT; }
4560 };
4561 typedef FrexpStruct< Signature<float, float, int> > FrexpStruct32Bit;
4562 typedef FrexpStruct< Signature<deFloat16, deFloat16, int> > FrexpStruct16Bit;
4563 typedef FrexpStruct< Signature<double, double, int> > FrexpStruct64Bit;
4564
4565 //Signature<float, float, int>
4566 //Signature<deFloat16, deFloat16, int>
4567 //Signature<double, double, int>
4568 template <class T>
4569 class LdExp : public PrimitiveFunc<T >
4570 {
4571 public:
4572 typedef typename LdExp::IRet IRet;
4573 typedef typename LdExp::IArgs IArgs;
4574
getName(void) const4575 string getName (void) const
4576 {
4577 return "ldexp";
4578 }
4579
4580 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4581 Interval doApply (const EvalContext& ctx, const IArgs& iargs) const
4582 {
4583 const int minExp = ctx.format.getMinExp();
4584 const int maxExp = ctx.format.getMaxExp();
4585 // Restrictions from the GLSL.std.450 instruction set.
4586 // See Khronos bugzilla 11180 for rationale.
4587 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4588 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4589 if (iargs.b.lo() < minExp) ret |= 0.0;
4590 if (!ret.isFinite(ctx.format.getMaxValue())) ret |= TCU_NAN;
4591 return ctx.format.convert(ret);
4592 }
4593 };
4594
4595 template <>
doApply(const EvalContext & ctx,const IArgs & iargs) const4596 Interval LdExp <Signature<double, double, int>>::doApply(const EvalContext& ctx, const IArgs& iargs) const
4597 {
4598 const int minExp = ctx.format.getMinExp();
4599 const int maxExp = ctx.format.getMaxExp();
4600 // Restrictions from the GLSL.std.450 instruction set.
4601 // See Khronos bugzilla 11180 for rationale.
4602 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4603 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4604 // Add 1ULP precision tolerance to account for differing rounding modes between the GPU and deLdExp.
4605 ret += Interval(-ctx.format.ulp(ret.lo()), ctx.format.ulp(ret.hi()));
4606 if (iargs.b.lo() < minExp) ret |= 0.0;
4607 if (!ret.isFinite(ctx.format.getMaxValue())) ret |= TCU_NAN;
4608 return ctx.format.convert(ret);
4609 }
4610
4611 template<int Rows, int Columns, class T>
4612 class Transpose : public PrimitiveFunc<Signature<Matrix<T, Rows, Columns>,
4613 Matrix<T, Columns, Rows> > >
4614 {
4615 public:
4616 typedef typename Transpose::IRet IRet;
4617 typedef typename Transpose::IArgs IArgs;
4618
getName(void) const4619 string getName (void) const
4620 {
4621 return "transpose";
4622 }
4623
4624 protected:
doApply(const EvalContext &,const IArgs & iargs) const4625 IRet doApply (const EvalContext&, const IArgs& iargs) const
4626 {
4627 IRet ret;
4628
4629 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
4630 {
4631 for (int colNdx = 0; colNdx < Columns; ++colNdx)
4632 ret(rowNdx, colNdx) = iargs.a(colNdx, rowNdx);
4633 }
4634
4635 return ret;
4636 }
4637 };
4638
4639 template<typename Ret, typename Arg0, typename Arg1>
4640 class MulFunc : public PrimitiveFunc<Signature<Ret, Arg0, Arg1> >
4641 {
4642 public:
getName(void) const4643 string getName (void) const { return "mul"; }
4644
4645 protected:
doPrint(ostream & os,const BaseArgExprs & args) const4646 void doPrint (ostream& os, const BaseArgExprs& args) const
4647 {
4648 os << "(" << *args[0] << " * " << *args[1] << ")";
4649 }
4650 };
4651
4652 template<typename T, int LeftRows, int Middle, int RightCols>
4653 class MatMul : public MulFunc<Matrix<T, LeftRows, RightCols>,
4654 Matrix<T, LeftRows, Middle>,
4655 Matrix<T, Middle, RightCols> >
4656 {
4657 protected:
4658 typedef typename MatMul::IRet IRet;
4659 typedef typename MatMul::IArgs IArgs;
4660 typedef typename MatMul::IArg0 IArg0;
4661 typedef typename MatMul::IArg1 IArg1;
4662
doApply(const EvalContext & ctx,const IArgs & iargs) const4663 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4664 {
4665 const IArg0& left = iargs.a;
4666 const IArg1& right = iargs.b;
4667 IRet ret;
4668
4669 for (int row = 0; row < LeftRows; ++row)
4670 {
4671 for (int col = 0; col < RightCols; ++col)
4672 {
4673 Interval element (0.0);
4674
4675 for (int ndx = 0; ndx < Middle; ++ndx)
4676 element = call<Add< Signature<T, T, T> > >(ctx, element,
4677 call<Mul< Signature<T, T, T> > >(ctx, left[ndx][row], right[col][ndx]));
4678
4679 ret[col][row] = element;
4680 }
4681 }
4682
4683 return ret;
4684 }
4685 };
4686
4687 template<typename T, int Rows, int Cols>
4688 class VecMatMul : public MulFunc<Vector<T, Cols>,
4689 Vector<T, Rows>,
4690 Matrix<T, Rows, Cols> >
4691 {
4692 public:
4693 typedef typename VecMatMul::IRet IRet;
4694 typedef typename VecMatMul::IArgs IArgs;
4695 typedef typename VecMatMul::IArg0 IArg0;
4696 typedef typename VecMatMul::IArg1 IArg1;
4697
4698 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4699 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4700 {
4701 const IArg0& left = iargs.a;
4702 const IArg1& right = iargs.b;
4703 IRet ret;
4704
4705 for (int col = 0; col < Cols; ++col)
4706 {
4707 Interval element (0.0);
4708
4709 for (int row = 0; row < Rows; ++row)
4710 element = call<Add< Signature<T, T, T> > >(ctx, element, call<Mul< Signature<T, T, T> > >(ctx, left[row], right[col][row]));
4711
4712 ret[col] = element;
4713 }
4714
4715 return ret;
4716 }
4717 };
4718
4719 template<int Rows, int Cols, class T>
4720 class MatVecMul : public MulFunc<Vector<T, Rows>,
4721 Matrix<T, Rows, Cols>,
4722 Vector<T, Cols> >
4723 {
4724 public:
4725 typedef typename MatVecMul::IRet IRet;
4726 typedef typename MatVecMul::IArgs IArgs;
4727 typedef typename MatVecMul::IArg0 IArg0;
4728 typedef typename MatVecMul::IArg1 IArg1;
4729
4730 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4731 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4732 {
4733 const IArg0& left = iargs.a;
4734 const IArg1& right = iargs.b;
4735
4736 return call<VecMatMul<T, Cols, Rows> >(ctx, right,
4737 call<Transpose<Rows, Cols, T> >(ctx, left));
4738 }
4739 };
4740
4741 template<int Rows, int Cols, class T>
4742 class OuterProduct : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>,
4743 Vector<T, Rows>,
4744 Vector<T, Cols> > >
4745 {
4746 public:
4747 typedef typename OuterProduct::IRet IRet;
4748 typedef typename OuterProduct::IArgs IArgs;
4749
getName(void) const4750 string getName (void) const
4751 {
4752 return "outerProduct";
4753 }
4754
4755 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4756 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
4757 {
4758 IRet ret;
4759
4760 for (int row = 0; row < Rows; ++row)
4761 {
4762 for (int col = 0; col < Cols; ++col)
4763 ret[col][row] = call<Mul< Signature<T, T, T> > >(ctx, iargs.a[row], iargs.b[col]);
4764 }
4765
4766 return ret;
4767 }
4768 };
4769
4770 template<int Rows, int Cols, class T>
outerProduct(const ExprP<Vector<T,Rows>> & left,const ExprP<Vector<T,Cols>> & right)4771 ExprP<Matrix<T, Rows, Cols> > outerProduct (const ExprP<Vector<T, Rows> >& left,
4772 const ExprP<Vector<T, Cols> >& right)
4773 {
4774 return app<OuterProduct<Rows, Cols, T> >(left, right);
4775 }
4776
4777 template<class T>
4778 class DeterminantBase : public DerivedFunc<T>
4779 {
4780 public:
getName(void) const4781 string getName (void) const { return "determinant"; }
4782 };
4783
4784 template<int Size> class Determinant;
4785 template<int Size> class Determinant16bit;
4786 template<int Size> class Determinant64bit;
4787
4788 template<int Size>
determinant(ExprP<Matrix<float,Size,Size>> mat)4789 ExprP<float> determinant (ExprP<Matrix<float, Size, Size> > mat)
4790 {
4791 return app<Determinant<Size> >(mat);
4792 }
4793
4794 template<int Size>
determinant(ExprP<Matrix<deFloat16,Size,Size>> mat)4795 ExprP<deFloat16> determinant (ExprP<Matrix<deFloat16, Size, Size> > mat)
4796 {
4797 return app<Determinant16bit<Size> >(mat);
4798 }
4799
4800 template<int Size>
determinant(ExprP<Matrix<double,Size,Size>> mat)4801 ExprP<double> determinant (ExprP<Matrix<double, Size, Size> > mat)
4802 {
4803 return app<Determinant64bit<Size> >(mat);
4804 }
4805
4806 template<>
4807 class Determinant<2> : public DeterminantBase<Signature<float, Matrix<float, 2, 2> > >
4808 {
4809 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4810 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4811 {
4812 ExprP<Mat2> mat = args.a;
4813
4814 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4815 }
4816 };
4817
// 3x3 determinant by cofactor expansion over mat[0][k], k = 0..2. Each factor
// in parentheses is the 2x2 minor, written so that it already carries the
// cofactor sign (the k = 1 term has its operands swapped).
template<>
class Determinant<3> : public DeterminantBase<Signature<float, Matrix<float, 3, 3> > >
{
protected:
	ExprP<Ret>	doExpand (ExpandContext&, const ArgExprs& args) const
	{
		ExprP<Mat3>	mat	= args.a;

		return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
				mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
				mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
	}
};
4831
// 4x4 determinant: cofactor expansion over mat[ndx][0], ndx = 0..3, with
// alternating signs. Each 3x3 minor is bound as a named "minor" expression in
// ctx and evaluated with the 3x3 determinant.
template<>
class Determinant<4> : public DeterminantBase<Signature<float, Matrix<float, 4, 4> > >
{
protected:
	ExprP<Ret>	doExpand (ExpandContext& ctx, const ArgExprs& args) const
	{
		ExprP<Mat4>	mat	= args.a;
		ExprP<Mat3>	minors[4];

		for (int ndx = 0; ndx < 4; ++ndx)
		{
			ExprP<Vec4>		minorColumns[3];
			ExprP<Vec3>		columns[3];

			// Select the three mat[i] with i != ndx.
			for (int col = 0; col < 3; ++col)
				minorColumns[col] = mat[col < ndx ? col : col + 1];

			// Drop element 0 of each selected vector. columns[col] gathers element
			// col+1 of all three, i.e. this builds the transpose of the minor, which
			// has the same determinant.
			for (int col = 0; col < 3; ++col)
				columns[col] = vec3(minorColumns[0][col+1],
									minorColumns[1][col+1],
									minorColumns[2][col+1]);

			minors[ndx] = bindExpression("minor", ctx,
										 mat3(columns[0], columns[1], columns[2]));
		}

		return (mat[0][0] * determinant(minors[0]) -
				mat[1][0] * determinant(minors[1]) +
				mat[2][0] * determinant(minors[2]) -
				mat[3][0] * determinant(minors[3]));
	}
};
4864
4865 template<>
4866 class Determinant16bit<2> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 2, 2> > >
4867 {
4868 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4869 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4870 {
4871 ExprP<Mat2_16b> mat = args.a;
4872
4873 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4874 }
4875 };
4876
// 16-bit 3x3 determinant by cofactor expansion over mat[0][k], k = 0..2. Each
// parenthesized 2x2 minor already carries its cofactor sign (the k = 1 term has
// its operands swapped).
template<>
class Determinant16bit<3> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 3, 3> > >
{
protected:
	ExprP<Ret> doExpand(ExpandContext&, const ArgExprs& args) const
	{
		ExprP<Mat3_16b> mat = args.a;

		return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
				mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
				mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
	}
};
4890
// 16-bit 4x4 determinant: cofactor expansion over mat[ndx][0], ndx = 0..3, with
// alternating signs. Each 3x3 minor is bound as a named "minor" expression in ctx.
template<>
class Determinant16bit<4> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 4, 4> > >
{
protected:
	ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
	{
		ExprP<Mat4_16b> mat = args.a;
		ExprP<Mat3_16b> minors[4];

		for (int ndx = 0; ndx < 4; ++ndx)
		{
			ExprP<Vec4_16Bit> minorColumns[3];
			ExprP<Vec3_16Bit> columns[3];

			// Select the three mat[i] with i != ndx.
			for (int col = 0; col < 3; ++col)
				minorColumns[col] = mat[col < ndx ? col : col + 1];

			// Drop element 0 of each; gathering element col+1 across the three
			// vectors builds the transposed minor (same determinant).
			for (int col = 0; col < 3; ++col)
				columns[col] = vec3(minorColumns[0][col + 1],
									minorColumns[1][col + 1],
									minorColumns[2][col + 1]);

			minors[ndx] = bindExpression("minor", ctx,
										 mat3(columns[0], columns[1], columns[2]));
		}

		return (mat[0][0] * determinant(minors[0]) -
				mat[1][0] * determinant(minors[1]) +
				mat[2][0] * determinant(minors[2]) -
				mat[3][0] * determinant(minors[3]));
	}
};
4923
4924 template<>
4925 class Determinant64bit<2> : public DeterminantBase<Signature<double, Matrix<double, 2, 2> > >
4926 {
4927 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4928 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& args) const
4929 {
4930 ExprP<Matrix2d> mat = args.a;
4931
4932 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
4933 }
4934 };
4935
// 64-bit 3x3 determinant by cofactor expansion over mat[0][k], k = 0..2. Each
// parenthesized 2x2 minor already carries its cofactor sign (the k = 1 term has
// its operands swapped).
template<>
class Determinant64bit<3> : public DeterminantBase<Signature<double, Matrix<double, 3, 3> > >
{
protected:
	ExprP<Ret> doExpand(ExpandContext&, const ArgExprs& args) const
	{
		ExprP<Matrix3d> mat = args.a;

		return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
				mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
				mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
	}
};
4949
// 64-bit 4x4 determinant: cofactor expansion over mat[ndx][0], ndx = 0..3, with
// alternating signs. Each 3x3 minor is bound as a named "minor" expression in ctx.
template<>
class Determinant64bit<4> : public DeterminantBase<Signature<double, Matrix<double, 4, 4> > >
{
protected:
	ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
	{
		ExprP<Matrix4d> mat = args.a;
		ExprP<Matrix3d> minors[4];

		for (int ndx = 0; ndx < 4; ++ndx)
		{
			ExprP<Vec4_64Bit> minorColumns[3];
			ExprP<Vec3_64Bit> columns[3];

			// Select the three mat[i] with i != ndx.
			for (int col = 0; col < 3; ++col)
				minorColumns[col] = mat[col < ndx ? col : col + 1];

			// Drop element 0 of each; gathering element col+1 across the three
			// vectors builds the transposed minor (same determinant).
			for (int col = 0; col < 3; ++col)
				columns[col] = vec3(minorColumns[0][col + 1],
									minorColumns[1][col + 1],
									minorColumns[2][col + 1]);

			minors[ndx] = bindExpression("minor", ctx,
										 mat3(columns[0], columns[1], columns[2]));
		}

		return (mat[0][0] * determinant(minors[0]) -
				mat[1][0] * determinant(minors[1]) +
				mat[2][0] * determinant(minors[2]) -
				mat[3][0] * determinant(minors[3]));
	}
};
4982
4983 template<int Size> class Inverse;
4984
4985 template <int Size>
inverse(ExprP<Matrix<float,Size,Size>> mat)4986 ExprP<Matrix<float, Size, Size> > inverse (ExprP<Matrix<float, Size, Size> > mat)
4987 {
4988 return app<Inverse<Size> >(mat);
4989 }
4990
4991 template<int Size> class Inverse16bit;
4992
4993 template <int Size>
inverse(ExprP<Matrix<deFloat16,Size,Size>> mat)4994 ExprP<Matrix<deFloat16, Size, Size> > inverse (ExprP<Matrix<deFloat16, Size, Size> > mat)
4995 {
4996 return app<Inverse16bit<Size> >(mat);
4997 }
4998
4999 template<int Size> class Inverse64bit;
5000
5001 template <int Size>
inverse(ExprP<Matrix<double,Size,Size>> mat)5002 ExprP<Matrix<double, Size, Size> > inverse (ExprP<Matrix<double, Size, Size> > mat)
5003 {
5004 return app<Inverse64bit<Size> >(mat);
5005 }
5006
5007 template<>
5008 class Inverse<2> : public DerivedFunc<Signature<Mat2, Mat2> >
5009 {
5010 public:
getName(void) const5011 string getName (void) const
5012 {
5013 return "inverse";
5014 }
5015
5016 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5017 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5018 {
5019 ExprP<Mat2> mat = args.a;
5020 ExprP<float> det = bindExpression("det", ctx, determinant(mat));
5021
5022 return mat2(vec2(mat[1][1] / det, -mat[0][1] / det),
5023 vec2(-mat[1][0] / det, mat[0][0] / det));
5024 }
5025 };
5026
// 3x3 inverse by block (Schur complement) inversion. With
//   A = 2x2 block from elements 0..1 of mat[0], mat[1]   (inverted as invA)
//   b = elements 0..1 of mat[2]                           (matB)
//   c = element 2 of mat[0], mat[1]                       (matC)
//   d = mat[2][2]                                         (matD)
// "schur" holds s = 1 / (d - c*A^-1*b), the reciprocal of the Schur complement, and
//   blockA = A^-1 + (A^-1*b*s) outer c * A^-1
//   blockB = -(A^-1*b) * s
//   blockC = -(c*A^-1) * s
// are reassembled with the same element placement as the input.
// The bindExpression order defines the named sub-expressions in ctx.
template<>
class Inverse<3> : public DerivedFunc<Signature<Mat3, Mat3> >
{
public:
	string		getName		(void) const
	{
		return "inverse";
	}

protected:
	ExprP<Ret>	doExpand	(ExpandContext& ctx, const ArgExprs& args) const
	{
		ExprP<Mat3>		mat		= args.a;
		ExprP<Mat2>		invA	= bindExpression("invA", ctx,
												 inverse(mat2(vec2(mat[0][0], mat[0][1]),
															  vec2(mat[1][0], mat[1][1]))));

		ExprP<Vec2>		matB	= bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
		ExprP<Vec2>		matC	= bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
		ExprP<float>	matD	= bindExpression("matD", ctx, mat[2][2]);

		ExprP<float>	schur	= bindExpression("schur", ctx,
												 constant(1.0f) /
												 (matD - dot(matC * invA, matB)));

		// blockA = invA + outerProduct(invA*matB*schur, matC) * invA, built stepwise.
		ExprP<Vec2>		t1		= invA * matB;
		ExprP<Vec2>		t2		= t1 * schur;
		ExprP<Mat2>		t3		= outerProduct(t2, matC);
		ExprP<Mat2>		t4		= t3 * invA;
		ExprP<Mat2>		t5		= invA + t4;
		ExprP<Mat2>		blockA	= bindExpression("blockA", ctx, t5);
		ExprP<Vec2>		blockB	= bindExpression("blockB", ctx,
												 (invA * matB) * -schur);
		ExprP<Vec2>		blockC	= bindExpression("blockC", ctx,
												 (matC * invA) * -schur);

		return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
					vec3(blockA[1][0], blockA[1][1], blockC[1]),
					vec3(blockB[0], blockB[1], schur));
	}
};
5068
// 4x4 inverse by 2x2 block (Schur complement) inversion. With 2x2 blocks
//   A (elements 0..1 of mat[0..1]), B (elements 0..1 of mat[2..3]),
//   C (elements 2..3 of mat[0..1]), D (elements 2..3 of mat[2..3]),
// "schur" holds S = (D - C*A^-1*B)^-1 and the result blocks are
//   blockA = A^-1 + A^-1*B*S*C*A^-1, blockB = -A^-1*B*S, blockC = -S*C*A^-1,
// reassembled with the same element placement as the input.
template<>
class Inverse<4> : public DerivedFunc<Signature<Mat4, Mat4> >
{
public:
	string		getName		(void) const { return "inverse"; }

protected:
	ExprP<Ret>			doExpand			(ExpandContext&		ctx,
											 const ArgExprs&	args) const
	{
		ExprP<Mat4>	mat		= args.a;
		ExprP<Mat2>	invA	= bindExpression("invA", ctx,
											 inverse(mat2(vec2(mat[0][0], mat[0][1]),
														  vec2(mat[1][0], mat[1][1]))));
		ExprP<Mat2>	matB	= bindExpression("matB", ctx,
											 mat2(vec2(mat[2][0], mat[2][1]),
												  vec2(mat[3][0], mat[3][1])));
		ExprP<Mat2>	matC	= bindExpression("matC", ctx,
											 mat2(vec2(mat[0][2], mat[0][3]),
												  vec2(mat[1][2], mat[1][3])));
		ExprP<Mat2>	matD	= bindExpression("matD", ctx,
											 mat2(vec2(mat[2][2], mat[2][3]),
												  vec2(mat[3][2], mat[3][3])));
		ExprP<Mat2>	schur	= bindExpression("schur", ctx,
											 inverse(matD + -(matC * invA * matB)));
		ExprP<Mat2>	blockA	= bindExpression("blockA", ctx,
											 invA + (invA * matB * schur * matC * invA));
		ExprP<Mat2>	blockB	= bindExpression("blockB", ctx,
											 (-invA) * matB * schur);
		ExprP<Mat2>	blockC	= bindExpression("blockC", ctx,
											 (-schur) * matC * invA);

		return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
					vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
					vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
					vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
	}
};
5107
5108 template<>
5109 class Inverse16bit<2> : public DerivedFunc<Signature<Mat2_16b, Mat2_16b> >
5110 {
5111 public:
getName(void) const5112 string getName (void) const
5113 {
5114 return "inverse";
5115 }
5116
5117 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5118 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5119 {
5120 ExprP<Mat2_16b> mat = args.a;
5121 ExprP<deFloat16> det = bindExpression("det", ctx, determinant(mat));
5122
5123 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)),
5124 vec2((-mat[1][0] / det), (mat[0][0] / det)));
5125 }
5126 };
5127
// 16-bit 3x3 inverse by block (Schur complement) inversion; same scheme as the
// 32-bit version. A = 2x2 block from mat[0..1] elements 0..1 (invA is its
// inverse), b = matB, c = matC, d = matD. "schur" holds 1 / (d - c*A^-1*b);
// blockA, blockB and blockC are the corresponding blocks of the inverse.
template<>
class Inverse16bit<3> : public DerivedFunc<Signature<Mat3_16b, Mat3_16b> >
{
public:
	string getName(void) const
	{
		return "inverse";
	}

protected:
	ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
	{
		ExprP<Mat3_16b>		mat = args.a;
		ExprP<Mat2_16b>		invA = bindExpression("invA", ctx,
			inverse(mat2(vec2(mat[0][0], mat[0][1]),
				vec2(mat[1][0], mat[1][1]))));

		ExprP<Vec2_16Bit>	matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
		ExprP<Vec2_16Bit>	matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
		ExprP<Mat3_16b::Scalar>	matD = bindExpression("matD", ctx, mat[2][2]);

		ExprP<Mat3_16b::Scalar>	schur = bindExpression("schur", ctx,
			constant((deFloat16)FLOAT16_1_0) /
			(matD - dot(matC * invA, matB)));

		// blockA = invA + outerProduct(invA*matB*schur, matC) * invA, built stepwise.
		ExprP<Vec2_16Bit>	t1 = invA * matB;
		ExprP<Vec2_16Bit>	t2 = t1 * schur;
		ExprP<Mat2_16b>		t3 = outerProduct(t2, matC);
		ExprP<Mat2_16b>		t4 = t3 * invA;
		ExprP<Mat2_16b>		t5 = invA + t4;
		ExprP<Mat2_16b>		blockA = bindExpression("blockA", ctx, t5);
		ExprP<Vec2_16Bit>	blockB = bindExpression("blockB", ctx,
			(invA * matB) * -schur);
		ExprP<Vec2_16Bit>	blockC = bindExpression("blockC", ctx,
			(matC * invA) * -schur);

		return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
			vec3(blockA[1][0], blockA[1][1], blockC[1]),
			vec3(blockB[0], blockB[1], schur));
	}
};
5169
// 16-bit 4x4 inverse by 2x2 block (Schur complement) inversion; same scheme as
// the 32-bit version: schur = (D - C*A^-1*B)^-1, blockA = A^-1 + A^-1*B*S*C*A^-1,
// blockB = -A^-1*B*S, blockC = -S*C*A^-1.
template<>
class Inverse16bit<4> : public DerivedFunc<Signature<Mat4_16b, Mat4_16b> >
{
public:
	string getName(void) const { return "inverse"; }

protected:
	ExprP<Ret> doExpand(ExpandContext& ctx,
		const ArgExprs& args) const
	{
		ExprP<Mat4_16b>	mat = args.a;
		ExprP<Mat2_16b>	invA = bindExpression("invA", ctx,
			inverse(mat2(vec2(mat[0][0], mat[0][1]),
				vec2(mat[1][0], mat[1][1]))));
		ExprP<Mat2_16b>	matB = bindExpression("matB", ctx,
			mat2(vec2(mat[2][0], mat[2][1]),
				vec2(mat[3][0], mat[3][1])));
		ExprP<Mat2_16b>	matC = bindExpression("matC", ctx,
			mat2(vec2(mat[0][2], mat[0][3]),
				vec2(mat[1][2], mat[1][3])));
		ExprP<Mat2_16b>	matD = bindExpression("matD", ctx,
			mat2(vec2(mat[2][2], mat[2][3]),
				vec2(mat[3][2], mat[3][3])));
		ExprP<Mat2_16b>	schur = bindExpression("schur", ctx,
			inverse(matD + -(matC * invA * matB)));
		ExprP<Mat2_16b>	blockA = bindExpression("blockA", ctx,
			invA + (invA * matB * schur * matC * invA));
		ExprP<Mat2_16b>	blockB = bindExpression("blockB", ctx,
			(-invA) * matB * schur);
		ExprP<Mat2_16b>	blockC = bindExpression("blockC", ctx,
			(-schur) * matC * invA);

		return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
			vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
			vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
			vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
	}
};
5208
5209 template<>
5210 class Inverse64bit<2> : public DerivedFunc<Signature<Matrix2d, Matrix2d> >
5211 {
5212 public:
getName(void) const5213 string getName (void) const
5214 {
5215 return "inverse";
5216 }
5217
5218 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5219 ExprP<Ret> doExpand (ExpandContext& ctx, const ArgExprs& args) const
5220 {
5221 ExprP<Matrix2d> mat = args.a;
5222 ExprP<double> det = bindExpression("det", ctx, determinant(mat));
5223
5224 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)),
5225 vec2((-mat[1][0] / det), (mat[0][0] / det)));
5226 }
5227 };
5228
// 64-bit 3x3 inverse by block (Schur complement) inversion; same scheme as the
// 32-bit version. A = 2x2 block from mat[0..1] elements 0..1 (invA is its
// inverse), b = matB, c = matC, d = matD. "schur" holds 1 / (d - c*A^-1*b);
// blockA, blockB and blockC are the corresponding blocks of the inverse.
template<>
class Inverse64bit<3> : public DerivedFunc<Signature<Matrix3d, Matrix3d> >
{
public:
	string getName(void) const
	{
		return "inverse";
	}

protected:
	ExprP<Ret> doExpand(ExpandContext& ctx, const ArgExprs& args) const
	{
		ExprP<Matrix3d>		mat = args.a;
		ExprP<Matrix2d>		invA = bindExpression("invA", ctx,
			inverse(mat2(vec2(mat[0][0], mat[0][1]),
				vec2(mat[1][0], mat[1][1]))));

		ExprP<Vec2_64Bit>	matB = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
		ExprP<Vec2_64Bit>	matC = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
		ExprP<Matrix3d::Scalar>	matD = bindExpression("matD", ctx, mat[2][2]);

		ExprP<Matrix3d::Scalar>	schur = bindExpression("schur", ctx,
			constant(1.0) /
			(matD - dot(matC * invA, matB)));

		// blockA = invA + outerProduct(invA*matB*schur, matC) * invA, built stepwise.
		ExprP<Vec2_64Bit>	t1 = invA * matB;
		ExprP<Vec2_64Bit>	t2 = t1 * schur;
		ExprP<Matrix2d>		t3 = outerProduct(t2, matC);
		ExprP<Matrix2d>		t4 = t3 * invA;
		ExprP<Matrix2d>		t5 = invA + t4;
		ExprP<Matrix2d>		blockA = bindExpression("blockA", ctx, t5);
		ExprP<Vec2_64Bit>	blockB = bindExpression("blockB", ctx,
			(invA * matB) * -schur);
		ExprP<Vec2_64Bit>	blockC = bindExpression("blockC", ctx,
			(matC * invA) * -schur);

		return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]),
			vec3(blockA[1][0], blockA[1][1], blockC[1]),
			vec3(blockB[0], blockB[1], schur));
	}
};
5270
// 64-bit 4x4 inverse by 2x2 block (Schur complement) inversion; same scheme as
// the 32-bit version: schur = (D - C*A^-1*B)^-1, blockA = A^-1 + A^-1*B*S*C*A^-1,
// blockB = -A^-1*B*S, blockC = -S*C*A^-1.
template<>
class Inverse64bit<4> : public DerivedFunc<Signature<Matrix4d, Matrix4d> >
{
public:
	string getName(void) const { return "inverse"; }

protected:
	ExprP<Ret> doExpand(ExpandContext& ctx,
		const ArgExprs& args) const
	{
		ExprP<Matrix4d>	mat = args.a;
		ExprP<Matrix2d>	invA = bindExpression("invA", ctx,
			inverse(mat2(vec2(mat[0][0], mat[0][1]),
				vec2(mat[1][0], mat[1][1]))));
		ExprP<Matrix2d>	matB = bindExpression("matB", ctx,
			mat2(vec2(mat[2][0], mat[2][1]),
				vec2(mat[3][0], mat[3][1])));
		ExprP<Matrix2d>	matC = bindExpression("matC", ctx,
			mat2(vec2(mat[0][2], mat[0][3]),
				vec2(mat[1][2], mat[1][3])));
		ExprP<Matrix2d>	matD = bindExpression("matD", ctx,
			mat2(vec2(mat[2][2], mat[2][3]),
				vec2(mat[3][2], mat[3][3])));
		ExprP<Matrix2d>	schur = bindExpression("schur", ctx,
			inverse(matD + -(matC * invA * matB)));
		ExprP<Matrix2d>	blockA = bindExpression("blockA", ctx,
			invA + (invA * matB * schur * matC * invA));
		ExprP<Matrix2d>	blockB = bindExpression("blockB", ctx,
			(-invA) * matB * schur);
		ExprP<Matrix2d>	blockC = bindExpression("blockC", ctx,
			(-schur) * matC * invA);

		return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
			vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
			vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
			vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
	}
};
5309
5310 //Signature<float, float, float, float>
5311 //Signature<deFloat16, deFloat16, deFloat16, deFloat16>
5312 //Signature<double, double, double, double>
5313 template <class T>
5314 class Fma : public DerivedFunc<T>
5315 {
5316 public:
5317 typedef typename Fma::ArgExprs ArgExprs;
5318 typedef typename Fma::Ret Ret;
5319
getName(void) const5320 string getName (void) const
5321 {
5322 return "fma";
5323 }
5324
5325 protected:
doExpand(ExpandContext &,const ArgExprs & x) const5326 ExprP<Ret> doExpand (ExpandContext&, const ArgExprs& x) const
5327 {
5328 return x.a * x.b + x.c;
5329 }
5330 };
5331
5332 } // Functions
5333
5334 using namespace Functions;
5335
// Element access on a container-typed expression (e.g. a vector component or
// a matrix's i-th vector): wraps *this as ExprP<T> and builds a getComponent
// node for index i.
template <typename T>
ExprP<typename T::Element> ContainerExprPBase<T>::operator[] (int i) const
{
	return Functions::getComponent(exprP<T>(*this), i);
}
5341
operator +(const ExprP<float> & arg0,const ExprP<float> & arg1)5342 ExprP<float> operator+ (const ExprP<float>& arg0, const ExprP<float>& arg1)
5343 {
5344 return app<Add< Signature<float, float, float> > >(arg0, arg1);
5345 }
5346
operator +(const ExprP<deFloat16> & arg0,const ExprP<deFloat16> & arg1)5347 ExprP<deFloat16> operator+ (const ExprP<deFloat16>& arg0, const ExprP<deFloat16>& arg1)
5348 {
5349 return app<Add< Signature<deFloat16, deFloat16, deFloat16> > >(arg0, arg1);
5350 }
5351
operator +(const ExprP<double> & arg0,const ExprP<double> & arg1)5352 ExprP<double> operator+ (const ExprP<double>& arg0, const ExprP<double>& arg1)
5353 {
5354 return app<Add< Signature<double, double, double> > >(arg0, arg1);
5355 }
5356
5357 template <typename T>
operator -(const ExprP<T> & arg0,const ExprP<T> & arg1)5358 ExprP<T> operator- (const ExprP<T>& arg0, const ExprP<T>& arg1)
5359 {
5360 return app<Sub <Signature <T,T,T> > >(arg0, arg1);
5361 }
5362
5363 template <typename T>
operator -(const ExprP<T> & arg0)5364 ExprP<T> operator- (const ExprP<T>& arg0)
5365 {
5366 return app<Negate< Signature<T, T> > >(arg0);
5367 }
5368
operator *(const ExprP<float> & arg0,const ExprP<float> & arg1)5369 ExprP<float> operator* (const ExprP<float>& arg0, const ExprP<float>& arg1)
5370 {
5371 return app<Mul< Signature<float, float, float> > >(arg0, arg1);
5372 }
5373
operator *(const ExprP<deFloat16> & arg0,const ExprP<deFloat16> & arg1)5374 ExprP<deFloat16> operator* (const ExprP<deFloat16>& arg0, const ExprP<deFloat16>& arg1)
5375 {
5376 return app<Mul< Signature<deFloat16, deFloat16, deFloat16> > >(arg0, arg1);
5377 }
5378
operator *(const ExprP<double> & arg0,const ExprP<double> & arg1)5379 ExprP<double> operator* (const ExprP<double>& arg0, const ExprP<double>& arg1)
5380 {
5381 return app<Mul< Signature<double, double, double> > >(arg0, arg1);
5382 }
5383
5384 template <typename T>
operator /(const ExprP<T> & arg0,const ExprP<T> & arg1)5385 ExprP<T> operator/ (const ExprP<T>& arg0, const ExprP<T>& arg1)
5386 {
5387 return app<Div< Signature<T, T, T> > >(arg0, arg1);
5388 }
5389
5390
5391 template <typename Sig_, int Size>
5392 class GenFunc : public PrimitiveFunc<Signature<
5393 typename ContainerOf<typename Sig_::Ret, Size>::Container,
5394 typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5395 typename ContainerOf<typename Sig_::Arg1, Size>::Container,
5396 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5397 typename ContainerOf<typename Sig_::Arg3, Size>::Container> >
5398 {
5399 public:
5400 typedef typename GenFunc::IArgs IArgs;
5401 typedef typename GenFunc::IRet IRet;
5402
GenFunc(const Func<Sig_> & scalarFunc)5403 GenFunc (const Func<Sig_>& scalarFunc) : m_func (scalarFunc) {}
5404
getSpirvCase(void) const5405 SpirVCaseT getSpirvCase (void) const
5406 {
5407 return m_func.getSpirvCase();
5408 }
5409
getName(void) const5410 string getName (void) const
5411 {
5412 return m_func.getName();
5413 }
5414
getOutParamIndex(void) const5415 int getOutParamIndex (void) const
5416 {
5417 return m_func.getOutParamIndex();
5418 }
5419
getRequiredExtension(void) const5420 string getRequiredExtension (void) const
5421 {
5422 return m_func.getRequiredExtension();
5423 }
5424
getInputRange(const bool is16bit) const5425 Interval getInputRange (const bool is16bit) const
5426 {
5427 return m_func.getInputRange(is16bit);
5428 }
5429
5430 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5431 void doPrint (ostream& os, const BaseArgExprs& args) const
5432 {
5433 m_func.print(os, args);
5434 }
5435
doApply(const EvalContext & ctx,const IArgs & iargs) const5436 IRet doApply (const EvalContext& ctx, const IArgs& iargs) const
5437 {
5438 IRet ret;
5439
5440 for (int ndx = 0; ndx < Size; ++ndx)
5441 {
5442 ret[ndx] =
5443 m_func.apply(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5444 }
5445
5446 return ret;
5447 }
5448
doFail(const EvalContext & ctx,const IArgs & iargs) const5449 IRet doFail (const EvalContext& ctx, const IArgs& iargs) const
5450 {
5451 IRet ret;
5452
5453 for (int ndx = 0; ndx < Size; ++ndx)
5454 {
5455 ret[ndx] =
5456 m_func.fail(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5457 }
5458
5459 return ret;
5460 }
5461
doGetUsedFuncs(FuncSet & dst) const5462 void doGetUsedFuncs (FuncSet& dst) const
5463 {
5464 m_func.getUsedFuncs(dst);
5465 }
5466
5467 const Func<Sig_>& m_func;
5468 };
5469
5470 template <typename F, int Size>
5471 class VectorizedFunc : public GenFunc<typename F::Sig, Size>
5472 {
5473 public:
VectorizedFunc(void)5474 VectorizedFunc (void) : GenFunc<typename F::Sig, Size>(instance<F>()) {}
5475 };
5476
5477 template <typename Sig_, int Size>
5478 class FixedGenFunc : public PrimitiveFunc <Signature<
5479 typename ContainerOf<typename Sig_::Ret, Size>::Container,
5480 typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5481 typename Sig_::Arg1,
5482 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5483 typename ContainerOf<typename Sig_::Arg3, Size>::Container> >
5484 {
5485 public:
5486 typedef typename FixedGenFunc::IArgs IArgs;
5487 typedef typename FixedGenFunc::IRet IRet;
5488
getName(void) const5489 string getName (void) const
5490 {
5491 return this->doGetScalarFunc().getName();
5492 }
5493
getSpirvCase(void) const5494 SpirVCaseT getSpirvCase (void) const
5495 {
5496 return this->doGetScalarFunc().getSpirvCase();
5497 }
5498
5499 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5500 void doPrint (ostream& os, const BaseArgExprs& args) const
5501 {
5502 this->doGetScalarFunc().print(os, args);
5503 }
5504
doApply(const EvalContext & ctx,const IArgs & iargs) const5505 IRet doApply (const EvalContext& ctx,
5506 const IArgs& iargs) const
5507 {
5508 IRet ret;
5509 const Func<Sig_>& func = this->doGetScalarFunc();
5510
5511 for (int ndx = 0; ndx < Size; ++ndx)
5512 ret[ndx] = func.apply(ctx, iargs.a[ndx], iargs.b, iargs.c[ndx], iargs.d[ndx]);
5513
5514 return ret;
5515 }
5516
5517 virtual const Func<Sig_>& doGetScalarFunc (void) const = 0;
5518 };
5519
5520 template <typename F, int Size>
5521 class FixedVecFunc : public FixedGenFunc<typename F::Sig, Size>
5522 {
5523 protected:
doGetScalarFunc(void) const5524 const Func<typename F::Sig>& doGetScalarFunc (void) const { return instance<F>(); }
5525 };
5526
5527 template<typename Sig>
5528 struct GenFuncs
5529 {
GenFuncsvkt::shaderexecutor::GenFuncs5530 GenFuncs (const Func<Sig>& func_,
5531 const GenFunc<Sig, 2>& func2_,
5532 const GenFunc<Sig, 3>& func3_,
5533 const GenFunc<Sig, 4>& func4_)
5534 : func (func_)
5535 , func2 (func2_)
5536 , func3 (func3_)
5537 , func4 (func4_)
5538 {}
5539
5540 const Func<Sig>& func;
5541 const GenFunc<Sig, 2>& func2;
5542 const GenFunc<Sig, 3>& func3;
5543 const GenFunc<Sig, 4>& func4;
5544 };
5545
5546 template<typename F>
makeVectorizedFuncs(void)5547 GenFuncs<typename F::Sig> makeVectorizedFuncs (void)
5548 {
5549 return GenFuncs<typename F::Sig>(instance<F>(),
5550 instance<VectorizedFunc<F, 2> >(),
5551 instance<VectorizedFunc<F, 3> >(),
5552 instance<VectorizedFunc<F, 4> >());
5553 }
5554
5555 template<typename T, int Size>
operator /(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5556 ExprP<Vector<T, Size> > operator/(const ExprP<Vector<T, Size> >& arg0,
5557 const ExprP<T>& arg1)
5558 {
5559 return app<FixedVecFunc<Div< Signature<T, T, T> >, Size> >(arg0, arg1);
5560 }
5561
5562 template<typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0)5563 ExprP<Vector<T, Size> > operator-(const ExprP<Vector<T, Size> >& arg0)
5564 {
5565 return app<VectorizedFunc<Negate< Signature<T, T> >, Size> >(arg0);
5566 }
5567
5568 template<typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5569 ExprP<Vector<T, Size> > operator-(const ExprP<Vector<T, Size> >& arg0,
5570 const ExprP<Vector<T, Size> >& arg1)
5571 {
5572 return app<VectorizedFunc<Sub<Signature<T, T, T> >, Size> >(arg0, arg1);
5573 }
5574
5575 template<int Size, typename T>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5576 ExprP<Vector<T, Size> > operator*(const ExprP<Vector<T, Size> >& arg0,
5577 const ExprP<T>& arg1)
5578 {
5579 return app<FixedVecFunc<Mul< Signature<T, T, T> >, Size> >(arg0, arg1);
5580 }
5581
5582 template<typename T, int Size>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5583 ExprP<Vector<T, Size> > operator*(const ExprP<Vector<T, Size> >& arg0,
5584 const ExprP<Vector<T, Size> >& arg1)
5585 {
5586 return app<VectorizedFunc<Mul< Signature<T, T, T> >, Size> >(arg0, arg1);
5587 }
5588
5589 template<int LeftRows, int Middle, int RightCols, typename T>
5590 ExprP<Matrix<T, LeftRows, RightCols> >
operator *(const ExprP<Matrix<T,LeftRows,Middle>> & left,const ExprP<Matrix<T,Middle,RightCols>> & right)5591 operator* (const ExprP<Matrix<T, LeftRows, Middle> >& left,
5592 const ExprP<Matrix<T, Middle, RightCols> >& right)
5593 {
5594 return app<MatMul<T, LeftRows, Middle, RightCols> >(left, right);
5595 }
5596
5597 template<int Rows, int Cols, typename T>
operator *(const ExprP<Vector<T,Cols>> & left,const ExprP<Matrix<T,Rows,Cols>> & right)5598 ExprP<Vector<T, Rows> > operator* (const ExprP<Vector<T, Cols> >& left,
5599 const ExprP<Matrix<T, Rows, Cols> >& right)
5600 {
5601 return app<VecMatMul<T, Rows, Cols> >(left, right);
5602 }
5603
5604 template<int Rows, int Cols, class T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<Vector<T,Rows>> & right)5605 ExprP<Vector<T, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
5606 const ExprP<Vector<T, Rows> >& right)
5607 {
5608 return app<MatVecMul<Rows, Cols, T> >(left, right);
5609 }
5610
5611 template<int Rows, int Cols, typename T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<T> & right)5612 ExprP<Matrix<T, Rows, Cols> > operator* (const ExprP<Matrix<T, Rows, Cols> >& left,
5613 const ExprP<T>& right)
5614 {
5615 return app<ScalarMatFunc<Mul< Signature<T, T, T> >, Rows, Cols> >(left, right);
5616 }
5617
5618 template<int Rows, int Cols>
operator +(const ExprP<Matrix<float,Rows,Cols>> & left,const ExprP<Matrix<float,Rows,Cols>> & right)5619 ExprP<Matrix<float, Rows, Cols> > operator+ (const ExprP<Matrix<float, Rows, Cols> >& left,
5620 const ExprP<Matrix<float, Rows, Cols> >& right)
5621 {
5622 return app<CompMatFunc<Add< Signature<float, float, float> >,float, Rows, Cols> >(left, right);
5623 }
5624
5625 template<int Rows, int Cols>
operator +(const ExprP<Matrix<deFloat16,Rows,Cols>> & left,const ExprP<Matrix<deFloat16,Rows,Cols>> & right)5626 ExprP<Matrix<deFloat16, Rows, Cols> > operator+ (const ExprP<Matrix<deFloat16, Rows, Cols> >& left,
5627 const ExprP<Matrix<deFloat16, Rows, Cols> >& right)
5628 {
5629 return app<CompMatFunc<Add< Signature<deFloat16, deFloat16, deFloat16> >, deFloat16, Rows, Cols> >(left, right);
5630 }
5631
5632 template<int Rows, int Cols>
operator +(const ExprP<Matrix<double,Rows,Cols>> & left,const ExprP<Matrix<double,Rows,Cols>> & right)5633 ExprP<Matrix<double, Rows, Cols> > operator+ (const ExprP<Matrix<double, Rows, Cols> >& left,
5634 const ExprP<Matrix<double, Rows, Cols> >& right)
5635 {
5636 return app<CompMatFunc<Add< Signature<double, double, double> >, double, Rows, Cols> >(left, right);
5637 }
5638
5639 template<typename T, int Rows, int Cols>
operator -(const ExprP<Matrix<T,Rows,Cols>> & mat)5640 ExprP<Matrix<T, Rows, Cols> > operator- (const ExprP<Matrix<T, Rows, Cols> >& mat)
5641 {
5642 return app<MatNeg<T, Rows, Cols> >(mat);
5643 }
5644
5645 template <typename T>
5646 class Sampling
5647 {
5648 public:
genFixeds(const FloatFormat &,const Precision,vector<T> &,const Interval &) const5649 virtual void genFixeds (const FloatFormat&, const Precision, vector<T>&, const Interval&) const {}
genRandom(const FloatFormat &,const Precision,Random &,const Interval &) const5650 virtual T genRandom (const FloatFormat&,const Precision, Random&, const Interval&) const { return T(); }
removeNotInRange(vector<T> &,const Interval &,const Precision) const5651 virtual void removeNotInRange (vector<T>&, const Interval&, const Precision) const {}
5652 };
5653
5654 template <>
5655 class DefaultSampling<Void> : public Sampling<Void>
5656 {
5657 public:
genFixeds(const FloatFormat &,const Precision,vector<Void> & dst,const Interval &) const5658 void genFixeds (const FloatFormat&, const Precision, vector<Void>& dst, const Interval&) const { dst.push_back(Void()); }
5659 };
5660
5661 template <>
5662 class DefaultSampling<bool> : public Sampling<bool>
5663 {
5664 public:
genFixeds(const FloatFormat &,const Precision,vector<bool> & dst,const Interval &) const5665 void genFixeds (const FloatFormat&, const Precision, vector<bool>& dst, const Interval&) const
5666 {
5667 dst.push_back(true);
5668 dst.push_back(false);
5669 }
5670 };
5671
5672 template <>
5673 class DefaultSampling<int> : public Sampling<int>
5674 {
5675 public:
genRandom(const FloatFormat &,const Precision prec,Random & rnd,const Interval &) const5676 int genRandom (const FloatFormat&, const Precision prec, Random& rnd, const Interval&) const
5677 {
5678 const int exp = rnd.getInt(0, getNumBits(prec)-2);
5679 const int sign = rnd.getBool() ? -1 : 1;
5680
5681 return sign * rnd.getInt(0, (deInt32)1 << exp);
5682 }
5683
genFixeds(const FloatFormat &,const Precision,vector<int> & dst,const Interval &) const5684 void genFixeds (const FloatFormat&, const Precision, vector<int>& dst, const Interval&) const
5685 {
5686 dst.push_back(0);
5687 dst.push_back(-1);
5688 dst.push_back(1);
5689 }
5690
5691 private:
getNumBits(Precision prec)5692 static inline int getNumBits (Precision prec)
5693 {
5694 switch (prec)
5695 {
5696 case glu::PRECISION_LAST:
5697 case glu::PRECISION_MEDIUMP: return 16;
5698 case glu::PRECISION_HIGHP: return 32;
5699 default:
5700 DE_ASSERT(false);
5701 return 0;
5702 }
5703 }
5704 };
5705
5706 template <>
5707 class DefaultSampling<float> : public Sampling<float>
5708 {
5709 public:
5710 float genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5711 void genFixeds (const FloatFormat& format, const Precision prec, vector<float>& dst, const Interval& inputRange) const;
5712 void removeNotInRange (vector<float>& dst, const Interval& inputRange, const Precision prec) const;
5713 };
5714
5715 template <>
5716 class DefaultSampling<double> : public Sampling<double>
5717 {
5718 public:
5719 double genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5720 void genFixeds (const FloatFormat& format, const Precision prec, vector<double>& dst, const Interval& inputRange) const;
5721 void removeNotInRange (vector<double>& dst, const Interval& inputRange, const Precision prec) const;
5722 };
5723
isDenorm16(deFloat16 v)5724 static bool isDenorm16(deFloat16 v)
5725 {
5726 const deUint16 mantissa = 0x03FF;
5727 const deUint16 exponent = 0x7C00;
5728 return ((exponent & v) == 0 && (mantissa & v) != 0);
5729 }
5730
5731 //! Generate a random double from a reasonable general-purpose distribution.
randomDouble(const FloatFormat & format,Random & rnd,const Interval & inputRange)5732 double randomDouble(const FloatFormat& format, Random& rnd, const Interval& inputRange)
5733 {
5734 // No testing of subnormals. TODO: Could integrate float controls for some operations.
5735 const int minExp = format.getMinExp();
5736 const int maxExp = format.getMaxExp();
5737 const bool haveSubnormal = false;
5738 const double midpoint = inputRange.midpoint();
5739
5740 // Choose exponent so that the cumulative distribution is cubic.
5741 // This makes the probability distribution quadratic, with the peak centered on zero.
5742 const double minRoot = deCbrt(minExp - 0.5 - (haveSubnormal ? 1.0 : 0.0));
5743 const double maxRoot = deCbrt(maxExp + 0.5);
5744 const int fractionBits = format.getFractionBits();
5745 const int exp = int(deRoundEven(dePow(rnd.getDouble(minRoot, maxRoot), 3.0)));
5746
5747 // Generate some occasional special numbers
5748 switch (rnd.getInt(0, 64))
5749 {
5750 case 0: return inputRange.contains(0) ? 0 : midpoint;
5751 case 1: return inputRange.contains(TCU_INFINITY) ? TCU_INFINITY : midpoint;
5752 case 2: return inputRange.contains(-TCU_INFINITY) ? -TCU_INFINITY : midpoint;
5753 case 3: return inputRange.contains(TCU_NAN) ? TCU_NAN : midpoint;
5754 default: break;
5755 }
5756
5757 DE_ASSERT(fractionBits < std::numeric_limits<double>::digits);
5758
5759 // Normal number
5760 double base = deLdExp(1.0, exp);
5761 double quantum = deLdExp(1.0, exp - fractionBits); // smallest representable difference in the binade
5762 double significand = 0.0;
5763 switch (rnd.getInt(0, 16))
5764 {
5765 case 0: // The highest number in this binade, significand is all bits one.
5766 significand = base - quantum;
5767 break;
5768 case 1: // Significand is one.
5769 significand = quantum;
5770 break;
5771 case 2: // Significand is zero.
5772 significand = 0.0;
5773 break;
5774 default: // Random (evenly distributed) significand.
5775 {
5776 deUint64 intFraction = rnd.getUint64() & ((1 << fractionBits) - 1);
5777 significand = double(intFraction) * quantum;
5778 }
5779 }
5780
5781 // Produce positive numbers more often than negative.
5782 double value = (rnd.getInt(0, 3) == 0 ? -1.0 : 1.0) * (base + significand);
5783 return inputRange.contains(value) ? value : midpoint;
5784 }
5785
5786 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const5787 float DefaultSampling<float>::genRandom (const FloatFormat& format,
5788 Precision prec,
5789 Random& rnd,
5790 const Interval& inputRange) const
5791 {
5792 DE_UNREF(prec);
5793 return (float)randomDouble(format, rnd, inputRange);
5794 }
5795
5796 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<float> & dst,const Interval & inputRange) const5797 void DefaultSampling<float>::genFixeds (const FloatFormat& format, const Precision prec, vector<float>& dst, const Interval& inputRange) const
5798 {
5799 const int minExp = format.getMinExp();
5800 const int maxExp = format.getMaxExp();
5801 const int fractionBits = format.getFractionBits();
5802 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
5803 const float minNormalized = deFloatLdExp(1.0f, minExp);
5804 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
5805
5806 // NaN
5807 dst.push_back(TCU_NAN);
5808 // Zero
5809 dst.push_back(0.0f);
5810
5811 for (int sign = -1; sign <= 1; sign += 2)
5812 {
5813 // Smallest normalized
5814 dst.push_back((float)sign * minNormalized);
5815
5816 // Next smallest normalized
5817 dst.push_back((float)sign * (minNormalized + minQuantum));
5818
5819 dst.push_back((float)sign * 0.5f);
5820 dst.push_back((float)sign * 1.0f);
5821 dst.push_back((float)sign * 2.0f);
5822
5823 // Largest number
5824 dst.push_back((float)sign * (deFloatLdExp(1.0f, maxExp) +
5825 (deFloatLdExp(1.0f, maxExp) - maxQuantum)));
5826
5827 dst.push_back((float)sign * TCU_INFINITY);
5828 }
5829 removeNotInRange(dst, inputRange, prec);
5830 }
5831
removeNotInRange(vector<float> & dst,const Interval & inputRange,const Precision prec) const5832 void DefaultSampling<float>::removeNotInRange (vector<float>& dst, const Interval& inputRange, const Precision prec) const
5833 {
5834 for (vector<float>::iterator it = dst.begin(); it < dst.end();)
5835 {
5836 // Remove out of range values. PRECISION_LAST means this is an FP16 test so remove any values that
5837 // will be denorms when converted to FP16. (This is used in the precision_fp16_storage32b test group).
5838 if ( !inputRange.contains(static_cast<double>(*it)) || (prec == glu::PRECISION_LAST && isDenorm16(deFloat32To16Round(*it, DE_ROUNDINGMODE_TO_ZERO))))
5839 it = dst.erase(it);
5840 else
5841 ++it;
5842 }
5843 }
5844
5845 //! Generate a random double from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const5846 double DefaultSampling<double>::genRandom (const FloatFormat& format,
5847 Precision prec,
5848 Random& rnd,
5849 const Interval& inputRange) const
5850 {
5851 DE_UNREF(prec);
5852 return randomDouble(format, rnd, inputRange);
5853 }
5854
5855 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<double> & dst,const Interval & inputRange) const5856 void DefaultSampling<double>::genFixeds (const FloatFormat& format, const Precision prec, vector<double>& dst, const Interval& inputRange) const
5857 {
5858 const int minExp = format.getMinExp();
5859 const int maxExp = format.getMaxExp();
5860 const int fractionBits = format.getFractionBits();
5861 const double minQuantum = deLdExp(1.0, minExp - fractionBits);
5862 const double minNormalized = deLdExp(1.0, minExp);
5863 const double maxQuantum = deLdExp(1.0, maxExp - fractionBits);
5864
5865 // NaN
5866 dst.push_back(TCU_NAN);
5867 // Zero
5868 dst.push_back(0.0);
5869
5870 for (int sign = -1; sign <= 1; sign += 2)
5871 {
5872 // Smallest normalized
5873 dst.push_back((double)sign * minNormalized);
5874
5875 // Next smallest normalized
5876 dst.push_back((double)sign * (minNormalized + minQuantum));
5877
5878 dst.push_back((double)sign * 0.5);
5879 dst.push_back((double)sign * 1.0);
5880 dst.push_back((double)sign * 2.0);
5881
5882 // Largest number
5883 dst.push_back((double)sign * (deLdExp(1.0, maxExp) + (deLdExp(1.0, maxExp) - maxQuantum)));
5884
5885 dst.push_back((double)sign * TCU_INFINITY);
5886 }
5887 removeNotInRange(dst, inputRange, prec);
5888 }
5889
removeNotInRange(vector<double> & dst,const Interval & inputRange,const Precision) const5890 void DefaultSampling<double>::removeNotInRange (vector<double>& dst, const Interval& inputRange, const Precision) const
5891 {
5892 for (vector<double>::iterator it = dst.begin(); it < dst.end();)
5893 {
5894 if ( !inputRange.contains(*it) )
5895 it = dst.erase(it);
5896 else
5897 ++it;
5898 }
5899 }
5900
5901 template <>
5902 class DefaultSampling<deFloat16> : public Sampling<deFloat16>
5903 {
5904 public:
5905 deFloat16 genRandom (const FloatFormat& format, const Precision prec, Random& rnd, const Interval& inputRange) const;
5906 void genFixeds (const FloatFormat& format, const Precision prec, vector<deFloat16>& dst, const Interval& inputRange) const;
5907 private:
5908 void removeNotInRange(vector<deFloat16>& dst, const Interval& inputRange, const Precision prec) const;
5909 };
5910
5911 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,const Precision prec,Random & rnd,const Interval & inputRange) const5912 deFloat16 DefaultSampling<deFloat16>::genRandom (const FloatFormat& format, const Precision prec,
5913 Random& rnd, const Interval& inputRange) const
5914 {
5915 DE_UNREF(prec);
5916 return deFloat64To16Round(randomDouble(format, rnd, inputRange), DE_ROUNDINGMODE_TO_NEAREST_EVEN);
5917 }
5918
5919 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<deFloat16> & dst,const Interval & inputRange) const5920 void DefaultSampling<deFloat16>::genFixeds (const FloatFormat& format, const Precision prec, vector<deFloat16>& dst, const Interval& inputRange) const
5921 {
5922 dst.push_back(deUint16(0x3E00)); //1.5
5923 dst.push_back(deUint16(0x3D00)); //1.25
5924 dst.push_back(deUint16(0x3F00)); //1.75
5925 // Zero
5926 dst.push_back(deUint16(0x0000));
5927 dst.push_back(deUint16(0x8000));
5928 // Infinity
5929 dst.push_back(deUint16(0x7c00));
5930 dst.push_back(deUint16(0xfc00));
5931 // SNaN
5932 dst.push_back(deUint16(0x7c0f));
5933 dst.push_back(deUint16(0xfc0f));
5934 // QNaN
5935 dst.push_back(deUint16(0x7cf0));
5936 dst.push_back(deUint16(0xfcf0));
5937 // Normalized
5938 dst.push_back(deUint16(0x0401));
5939 dst.push_back(deUint16(0x8401));
5940 // Some normal number
5941 dst.push_back(deUint16(0x14cb));
5942 dst.push_back(deUint16(0x94cb));
5943
5944 const int minExp = format.getMinExp();
5945 const int maxExp = format.getMaxExp();
5946 const int fractionBits = format.getFractionBits();
5947 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
5948 const float minNormalized = deFloatLdExp(1.0f, minExp);
5949 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
5950
5951 for (float sign = -1.0; sign <= 1.0f; sign += 2.0f)
5952 {
5953 // Smallest normalized
5954 dst.push_back(deFloat32To16Round(sign * minNormalized, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5955
5956 // Next smallest normalized
5957 dst.push_back(deFloat32To16Round(sign * (minNormalized + minQuantum), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5958
5959 dst.push_back(deFloat32To16Round(sign * 0.5f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5960 dst.push_back(deFloat32To16Round(sign * 1.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5961 dst.push_back(deFloat32To16Round(sign * 2.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5962
5963 // Largest number
5964 dst.push_back(deFloat32To16Round(sign * (deFloatLdExp(1.0f, maxExp) +
5965 (deFloatLdExp(1.0f, maxExp) - maxQuantum)), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5966
5967 dst.push_back(deFloat32To16Round(sign * TCU_INFINITY, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
5968 }
5969 removeNotInRange(dst, inputRange, prec);
5970 }
5971
removeNotInRange(vector<deFloat16> & dst,const Interval & inputRange,const Precision) const5972 void DefaultSampling<deFloat16>::removeNotInRange(vector<deFloat16>& dst, const Interval& inputRange, const Precision) const
5973 {
5974 for (vector<deFloat16>::iterator it = dst.begin(); it < dst.end();)
5975 {
5976 if (inputRange.contains(static_cast<double>(*it)))
5977 ++it;
5978 else
5979 it = dst.erase(it);
5980 }
5981 }
5982
5983 template <typename T, int Size>
5984 class DefaultSampling<Vector<T, Size> > : public Sampling<Vector<T, Size> >
5985 {
5986 public:
5987 typedef Vector<T, Size> Value;
5988
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const5989 Value genRandom (const FloatFormat& fmt, const Precision prec, Random& rnd, const Interval& inputRange) const
5990 {
5991 Value ret;
5992
5993 for (int ndx = 0; ndx < Size; ++ndx)
5994 ret[ndx] = instance<DefaultSampling<T> >().genRandom(fmt, prec, rnd, inputRange);
5995
5996 return ret;
5997 }
5998
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const5999 void genFixeds (const FloatFormat& fmt, const Precision prec, vector<Value>& dst, const Interval& inputRange) const
6000 {
6001 vector<T> scalars;
6002
6003 instance<DefaultSampling<T> >().genFixeds(fmt, prec, scalars, inputRange);
6004
6005 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
6006 dst.push_back(Value(scalars[scalarNdx]));
6007 }
6008 };
6009
6010 template <typename T, int Rows, int Columns>
6011 class DefaultSampling<Matrix<T, Rows, Columns> > : public Sampling<Matrix<T, Rows, Columns> >
6012 {
6013 public:
6014 typedef Matrix<T, Rows, Columns> Value;
6015
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const6016 Value genRandom (const FloatFormat& fmt, const Precision prec, Random& rnd, const Interval& inputRange) const
6017 {
6018 Value ret;
6019
6020 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
6021 for (int colNdx = 0; colNdx < Columns; ++colNdx)
6022 ret(rowNdx, colNdx) = instance<DefaultSampling<T> >().genRandom(fmt, prec, rnd, inputRange);
6023
6024 return ret;
6025 }
6026
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const6027 void genFixeds (const FloatFormat& fmt, const Precision prec, vector<Value>& dst, const Interval& inputRange) const
6028 {
6029 vector<T> scalars;
6030
6031 instance<DefaultSampling<T> >().genFixeds(fmt, prec, scalars, inputRange);
6032
6033 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
6034 dst.push_back(Value(scalars[scalarNdx]));
6035
6036 if (Columns == Rows)
6037 {
6038 Value mat (T(0.0));
6039 T x = T(1.0f);
6040 mat[0][0] = x;
6041 for (int ndx = 0; ndx < Columns; ++ndx)
6042 {
6043 mat[Columns-1-ndx][ndx] = x;
6044 x = static_cast<T>(x * static_cast<T>(2.0f));
6045 }
6046 dst.push_back(mat);
6047 }
6048 }
6049 };
6050
6051 struct CaseContext
6052 {
CaseContextvkt::shaderexecutor::CaseContext6053 CaseContext (const string& name_,
6054 TestContext& testContext_,
6055 const FloatFormat& floatFormat_,
6056 const FloatFormat& highpFormat_,
6057 const Precision precision_,
6058 const ShaderType shaderType_,
6059 const size_t numRandoms_,
6060 const PrecisionTestFeatures precisionTestFeatures_ = PRECISION_TEST_FEATURES_NONE,
6061 const bool isPackFloat16b_ = false,
6062 const bool isFloat64b_ = false)
6063 : name (name_)
6064 , testContext (testContext_)
6065 , floatFormat (floatFormat_)
6066 , highpFormat (highpFormat_)
6067 , precision (precision_)
6068 , shaderType (shaderType_)
6069 , numRandoms (numRandoms_)
6070 , inputRange (-TCU_INFINITY, TCU_INFINITY)
6071 , precisionTestFeatures (precisionTestFeatures_)
6072 , isPackFloat16b (isPackFloat16b_)
6073 , isFloat64b (isFloat64b_)
6074 {}
6075
6076 string name;
6077 TestContext& testContext;
6078 FloatFormat floatFormat;
6079 FloatFormat highpFormat;
6080 Precision precision;
6081 ShaderType shaderType;
6082 size_t numRandoms;
6083 Interval inputRange;
6084 PrecisionTestFeatures precisionTestFeatures;
6085 bool isPackFloat16b;
6086 bool isFloat64b;
6087 };
6088
6089 template<typename In0_ = Void, typename In1_ = Void, typename In2_ = Void, typename In3_ = Void>
6090 struct InTypes
6091 {
6092 typedef In0_ In0;
6093 typedef In1_ In1;
6094 typedef In2_ In2;
6095 typedef In3_ In3;
6096 };
6097
6098 template <typename In>
numInputs(void)6099 int numInputs (void)
6100 {
6101 return (!isTypeValid<typename In::In0>() ? 0 :
6102 !isTypeValid<typename In::In1>() ? 1 :
6103 !isTypeValid<typename In::In2>() ? 2 :
6104 !isTypeValid<typename In::In3>() ? 3 :
6105 4);
6106 }
6107
6108 template<typename Out0_, typename Out1_ = Void>
6109 struct OutTypes
6110 {
6111 typedef Out0_ Out0;
6112 typedef Out1_ Out1;
6113 };
6114
6115 template <typename Out>
numOutputs(void)6116 int numOutputs (void)
6117 {
6118 return (!isTypeValid<typename Out::Out0>() ? 0 :
6119 !isTypeValid<typename Out::Out1>() ? 1 :
6120 2);
6121 }
6122
6123 template<typename In>
6124 struct Inputs
6125 {
6126 vector<typename In::In0> in0;
6127 vector<typename In::In1> in1;
6128 vector<typename In::In2> in2;
6129 vector<typename In::In3> in3;
6130 };
6131
6132 template<typename Out>
6133 struct Outputs
6134 {
Outputsvkt::shaderexecutor::Outputs6135 Outputs (size_t size) : out0(size), out1(size) {}
6136
6137 vector<typename Out::Out0> out0;
6138 vector<typename Out::Out1> out1;
6139 };
6140
6141 template<typename In, typename Out>
6142 struct Variables
6143 {
6144 VariableP<typename In::In0> in0;
6145 VariableP<typename In::In1> in1;
6146 VariableP<typename In::In2> in2;
6147 VariableP<typename In::In3> in3;
6148 VariableP<typename Out::Out0> out0;
6149 VariableP<typename Out::Out1> out1;
6150 };
6151
6152 template<typename In>
6153 struct Samplings
6154 {
Samplingsvkt::shaderexecutor::Samplings6155 Samplings (const Sampling<typename In::In0>& in0_,
6156 const Sampling<typename In::In1>& in1_,
6157 const Sampling<typename In::In2>& in2_,
6158 const Sampling<typename In::In3>& in3_)
6159 : in0 (in0_), in1 (in1_), in2 (in2_), in3 (in3_) {}
6160
6161 const Sampling<typename In::In0>& in0;
6162 const Sampling<typename In::In1>& in1;
6163 const Sampling<typename In::In2>& in2;
6164 const Sampling<typename In::In3>& in3;
6165 };
6166
6167 template<typename In>
6168 struct DefaultSamplings : Samplings<In>
6169 {
DefaultSamplingsvkt::shaderexecutor::DefaultSamplings6170 DefaultSamplings (void)
6171 : Samplings<In>(instance<DefaultSampling<typename In::In0> >(),
6172 instance<DefaultSampling<typename In::In1> >(),
6173 instance<DefaultSampling<typename In::In2> >(),
6174 instance<DefaultSampling<typename In::In3> >()) {}
6175 };
6176
6177 template <typename In, typename Out>
6178 class BuiltinPrecisionCaseTestInstance : public TestInstance
6179 {
6180 public:
BuiltinPrecisionCaseTestInstance(Context & context,const CaseContext caseCtx,const ShaderSpec & shaderSpec,const Variables<In,Out> variables,const Samplings<In> & samplings,const StatementP stmt,bool modularOp=false)6181 BuiltinPrecisionCaseTestInstance (Context& context,
6182 const CaseContext caseCtx,
6183 const ShaderSpec& shaderSpec,
6184 const Variables<In, Out> variables,
6185 const Samplings<In>& samplings,
6186 const StatementP stmt,
6187 bool modularOp = false)
6188 : TestInstance (context)
6189 , m_caseCtx (caseCtx)
6190 , m_variables (variables)
6191 , m_samplings (samplings)
6192 , m_stmt (stmt)
6193 , m_executor (createExecutor(context, caseCtx.shaderType, shaderSpec))
6194 , m_modularOp (modularOp)
6195 {
6196 }
6197 virtual tcu::TestStatus iterate (void);
6198
6199 protected:
6200 CaseContext m_caseCtx;
6201 Variables<In, Out> m_variables;
6202 const Samplings<In>& m_samplings;
6203 StatementP m_stmt;
6204 de::UniquePtr<ShaderExecutor> m_executor;
6205 bool m_modularOp;
6206 };
6207
6208 template<class In, class Out>
iterate(void)6209 tcu::TestStatus BuiltinPrecisionCaseTestInstance<In, Out>::iterate (void)
6210 {
6211 typedef typename In::In0 In0;
6212 typedef typename In::In1 In1;
6213 typedef typename In::In2 In2;
6214 typedef typename In::In3 In3;
6215 typedef typename Out::Out0 Out0;
6216 typedef typename Out::Out1 Out1;
6217
6218 areFeaturesSupported(m_context, m_caseCtx.precisionTestFeatures);
6219 Inputs<In> inputs = generateInputs(m_samplings, m_caseCtx.floatFormat, m_caseCtx.precision, m_caseCtx.numRandoms, 0xdeadbeefu + m_caseCtx.testContext.getCommandLine().getBaseSeed(), m_caseCtx.inputRange);
6220 const FloatFormat& fmt = m_caseCtx.floatFormat;
6221 const int inCount = numInputs<In>();
6222 const int outCount = numOutputs<Out>();
6223 const size_t numValues = (inCount > 0) ? inputs.in0.size() : 1;
6224 Outputs<Out> outputs (numValues);
6225 const FloatFormat highpFmt = m_caseCtx.highpFormat;
6226 const int maxMsgs = 100;
6227 int numErrors = 0;
6228 Environment env; // Hoisted out of the inner loop for optimization.
6229 ResultCollector status;
6230 TestLog& testLog = m_context.getTestContext().getLog();
6231
6232 // Module operations need exactly two inputs and have exactly one output.
6233 if (m_modularOp)
6234 {
6235 DE_ASSERT(inCount == 2);
6236 DE_ASSERT(outCount == 1);
6237 }
6238
6239 const void* inputArr[] =
6240 {
6241 inputs.in0.data(), inputs.in1.data(), inputs.in2.data(), inputs.in3.data(),
6242 };
6243 void* outputArr[] =
6244 {
6245 outputs.out0.data(), outputs.out1.data(),
6246 };
6247
6248 // Print out the statement and its definitions
6249 testLog << TestLog::Message << "Statement: " << m_stmt << TestLog::EndMessage;
6250 {
6251 ostringstream oss;
6252 FuncSet funcs;
6253
6254 m_stmt->getUsedFuncs(funcs);
6255 for (FuncSet::const_iterator it = funcs.begin(); it != funcs.end(); ++it)
6256 {
6257 (*it)->printDefinition(oss);
6258 }
6259 if (!funcs.empty())
6260 testLog << TestLog::Message << "Reference definitions:\n" << oss.str()
6261 << TestLog::EndMessage;
6262 }
6263 switch (inCount)
6264 {
6265 case 4:
6266 DE_ASSERT(inputs.in3.size() == numValues);
6267 // Fallthrough
6268 case 3:
6269 DE_ASSERT(inputs.in2.size() == numValues);
6270 // Fallthrough
6271 case 2:
6272 DE_ASSERT(inputs.in1.size() == numValues);
6273 // Fallthrough
6274 case 1:
6275 DE_ASSERT(inputs.in0.size() == numValues);
6276 // Fallthrough
6277 default:
6278 break;
6279 }
6280
6281 m_executor->execute(int(numValues), inputArr, outputArr);
6282
6283 // Initialize environment with unused values so we don't need to bind in inner loop.
6284 {
6285 const typename Traits<In0>::IVal in0;
6286 const typename Traits<In1>::IVal in1;
6287 const typename Traits<In2>::IVal in2;
6288 const typename Traits<In3>::IVal in3;
6289 const typename Traits<Out0>::IVal reference0;
6290 const typename Traits<Out1>::IVal reference1;
6291
6292 env.bind(*m_variables.in0, in0);
6293 env.bind(*m_variables.in1, in1);
6294 env.bind(*m_variables.in2, in2);
6295 env.bind(*m_variables.in3, in3);
6296 env.bind(*m_variables.out0, reference0);
6297 env.bind(*m_variables.out1, reference1);
6298 }
6299
6300 // For each input tuple, compute output reference interval and compare
6301 // shader output to the reference.
6302 for (size_t valueNdx = 0; valueNdx < numValues; valueNdx++)
6303 {
6304 bool result = true;
6305 const bool isInput16Bit = m_executor->areInputs16Bit();
6306 const bool isInput64Bit = m_executor->areInputs64Bit();
6307
6308 DE_ASSERT(!(isInput16Bit && isInput64Bit));
6309
6310 typename Traits<Out0>::IVal reference0;
6311 typename Traits<Out1>::IVal reference1;
6312
6313 if (valueNdx % (size_t)TOUCH_WATCHDOG_VALUE_FREQUENCY == 0)
6314 m_context.getTestContext().touchWatchdog();
6315
6316 env.lookup(*m_variables.in0) = convert<In0>(fmt, round(fmt, inputs.in0[valueNdx]));
6317 env.lookup(*m_variables.in1) = convert<In1>(fmt, round(fmt, inputs.in1[valueNdx]));
6318 env.lookup(*m_variables.in2) = convert<In2>(fmt, round(fmt, inputs.in2[valueNdx]));
6319 env.lookup(*m_variables.in3) = convert<In3>(fmt, round(fmt, inputs.in3[valueNdx]));
6320
6321 {
6322 EvalContext ctx (fmt, m_caseCtx.precision, env, 0);
6323 m_stmt->execute(ctx);
6324
6325 switch (outCount)
6326 {
6327 case 2:
6328 reference1 = convert<Out1>(highpFmt, env.lookup(*m_variables.out1));
6329 if (!status.check(contains(reference1, outputs.out1[valueNdx], m_caseCtx.isPackFloat16b), "Shader output 1 is outside acceptable range"))
6330 result = false;
6331 // Fallthrough
6332 case 1:
6333 {
6334 // Pass b from mod(a, b) if we are in the modulo operation.
6335 const tcu::Maybe<In1> modularDivisor = (m_modularOp ? tcu::just(inputs.in1[valueNdx]) : tcu::Nothing);
6336
6337 reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
6338 if (!status.check(contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor), "Shader output 0 is outside acceptable range"))
6339 {
6340 m_stmt->failed(ctx);
6341 reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
6342 if (!status.check(contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor), "Shader output 0 is outside acceptable range"))
6343 result = false;
6344 }
6345 }
6346 // Fallthrough
6347 default: break;
6348 }
6349
6350 }
6351 if (!result)
6352 ++numErrors;
6353
6354 if ((!result && numErrors <= maxMsgs) || GLS_LOG_ALL_RESULTS)
6355 {
6356 MessageBuilder builder = testLog.message();
6357
6358 builder << (result ? "Passed" : "Failed") << " sample:\n";
6359
6360 if (inCount > 0)
6361 {
6362 builder << "\t" << m_variables.in0->getName() << " = "
6363 << (isInput64Bit ? value64ToString(highpFmt, inputs.in0[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in0[valueNdx]) : value32ToString(highpFmt, inputs.in0[valueNdx]))) << "\n";
6364 }
6365
6366 if (inCount > 1)
6367 {
6368 builder << "\t" << m_variables.in1->getName() << " = "
6369 << (isInput64Bit ? value64ToString(highpFmt, inputs.in1[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in1[valueNdx]) : value32ToString(highpFmt, inputs.in1[valueNdx]))) << "\n";
6370 }
6371
6372 if (inCount > 2)
6373 {
6374 builder << "\t" << m_variables.in2->getName() << " = "
6375 << (isInput64Bit ? value64ToString(highpFmt, inputs.in2[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in2[valueNdx]) : value32ToString(highpFmt, inputs.in2[valueNdx]))) << "\n";
6376 }
6377
6378 if (inCount > 3)
6379 {
6380 builder << "\t" << m_variables.in3->getName() << " = "
6381 << (isInput64Bit ? value64ToString(highpFmt, inputs.in3[valueNdx]) : (isInput16Bit ? value16ToString(highpFmt, inputs.in3[valueNdx]) : value32ToString(highpFmt, inputs.in3[valueNdx]))) << "\n";
6382 }
6383
6384 if (outCount > 0)
6385 {
6386 if (m_executor->spirvCase() == SPIRV_CASETYPE_COMPARE)
6387 {
6388 builder << "Output:\n"
6389 << comparisonMessage(outputs.out0[valueNdx])
6390 << "Expected result:\n"
6391 << comparisonMessageInterval<typename Out::Out0>(reference0) << "\n";
6392 }
6393 else
6394 {
6395 builder << "\t" << m_variables.out0->getName() << " = "
6396 << (m_executor->isOutput64Bit(0u) ? value64ToString(highpFmt, outputs.out0[valueNdx]) : (m_executor->isOutput16Bit(0u) || m_caseCtx.isPackFloat16b ? value16ToString(highpFmt, outputs.out0[valueNdx]) : value32ToString(highpFmt, outputs.out0[valueNdx]))) << "\n"
6397 << "\tExpected range: "
6398 << intervalToString<typename Out::Out0>(highpFmt, reference0) << "\n";
6399 }
6400 }
6401
6402 if (outCount > 1)
6403 {
6404 builder << "\t" << m_variables.out1->getName() << " = "
6405 << (m_executor->isOutput64Bit(1u) ? value64ToString(highpFmt, outputs.out1[valueNdx]) : (m_executor->isOutput16Bit(1u) || m_caseCtx.isPackFloat16b ? value16ToString(highpFmt, outputs.out1[valueNdx]) : value32ToString(highpFmt, outputs.out1[valueNdx]))) << "\n"
6406 << "\tExpected range: "
6407 << intervalToString<typename Out::Out1>(highpFmt, reference1) << "\n";
6408 }
6409
6410 builder << TestLog::EndMessage;
6411 }
6412 }
6413
6414 if (numErrors > maxMsgs)
6415 {
6416 testLog << TestLog::Message << "(Skipped " << (numErrors - maxMsgs) << " messages.)"
6417 << TestLog::EndMessage;
6418 }
6419
6420 if (numErrors == 0)
6421 {
6422 testLog << TestLog::Message << "All " << numValues << " inputs passed."
6423 << TestLog::EndMessage;
6424 }
6425 else
6426 {
6427 testLog << TestLog::Message << numErrors << "/" << numValues << " inputs failed."
6428 << TestLog::EndMessage;
6429 }
6430
6431 if (numErrors)
6432 return tcu::TestStatus::fail(de::toString(numErrors) + string(" test failed. Check log for the details"));
6433 else
6434 return tcu::TestStatus::pass("Pass");
6435
6436 }
6437
6438 class PrecisionCase : public TestCase
6439 {
6440 protected:
PrecisionCase(const CaseContext & context,const string & name,const Interval & inputRange,const string & extension="")6441 PrecisionCase (const CaseContext& context, const string& name, const Interval& inputRange, const string& extension = "")
6442 : TestCase (context.testContext, name.c_str())
6443 , m_ctx (context)
6444 , m_extension (extension)
6445 {
6446 m_ctx.inputRange = inputRange;
6447 m_spec.packFloat16Bit = context.isPackFloat16b;
6448 }
6449
initPrograms(vk::SourceCollections & programCollection) const6450 virtual void initPrograms (vk::SourceCollections& programCollection) const
6451 {
6452 generateSources(m_ctx.shaderType, m_spec, programCollection);
6453 }
6454
getFormat(void) const6455 const FloatFormat& getFormat (void) const { return m_ctx.floatFormat; }
6456
6457 template <typename In, typename Out>
6458 void testStatement (const Variables<In, Out>& variables, const Statement& stmt, SpirVCaseT spirvCase);
6459
6460 template<typename T>
makeSymbol(const Variable<T> & variable)6461 Symbol makeSymbol (const Variable<T>& variable)
6462 {
6463 return Symbol(variable.getName(), getVarTypeOf<T>(m_ctx.precision));
6464 }
6465
6466 CaseContext m_ctx;
6467 const string m_extension;
6468 ShaderSpec m_spec;
6469 };
6470
6471 template <typename In, typename Out>
testStatement(const Variables<In,Out> & variables,const Statement & stmt,SpirVCaseT spirvCase)6472 void PrecisionCase::testStatement (const Variables<In, Out>& variables, const Statement& stmt, SpirVCaseT spirvCase)
6473 {
6474 const int inCount = numInputs<In>();
6475 const int outCount = numOutputs<Out>();
6476 Environment env; // Hoisted out of the inner loop for optimization.
6477
6478 // Initialize ShaderSpec from precision, variables and statement.
6479 if (m_ctx.precision != glu::PRECISION_LAST)
6480 {
6481 ostringstream os;
6482 os << "precision " << glu::getPrecisionName(m_ctx.precision) << " float;\n";
6483 m_spec.globalDeclarations = os.str();
6484 }
6485
6486 if (!m_extension.empty())
6487 m_spec.globalDeclarations = "#extension " + m_extension + " : require\n";
6488
6489 m_spec.inputs.resize(inCount);
6490
6491 switch (inCount)
6492 {
6493 case 4:
6494 m_spec.inputs[3] = makeSymbol(*variables.in3);
6495 // Fallthrough
6496 case 3:
6497 m_spec.inputs[2] = makeSymbol(*variables.in2);
6498 // Fallthrough
6499 case 2:
6500 m_spec.inputs[1] = makeSymbol(*variables.in1);
6501 // Fallthrough
6502 case 1:
6503 m_spec.inputs[0] = makeSymbol(*variables.in0);
6504 // Fallthrough
6505 default:
6506 break;
6507 }
6508
6509 bool inputs16Bit = false;
6510 for (vector<Symbol>::const_iterator symIter = m_spec.inputs.begin(); symIter != m_spec.inputs.end(); ++symIter)
6511 inputs16Bit = inputs16Bit || glu::isDataTypeFloat16OrVec(symIter->varType.getBasicType());
6512
6513 if (inputs16Bit || m_spec.packFloat16Bit)
6514 m_spec.globalDeclarations += "#extension GL_EXT_shader_explicit_arithmetic_types: require\n";
6515
6516 m_spec.outputs.resize(outCount);
6517
6518 switch (outCount)
6519 {
6520 case 2:
6521 m_spec.outputs[1] = makeSymbol(*variables.out1);
6522 // Fallthrough
6523 case 1:
6524 m_spec.outputs[0] = makeSymbol(*variables.out0);
6525 // Fallthrough
6526 default:
6527 break;
6528 }
6529
6530 m_spec.source = de::toString(stmt);
6531 m_spec.spirvCase = spirvCase;
6532 }
6533
//! Ordering used to de-duplicate generated inputs. The generic version simply
//! defers to operator<; specializations refine it for floats and aggregates.
template <typename T>
struct InputLess
{
	bool operator() (const T& lhs, const T& rhs) const
	{
		const bool lhsFirst = lhs < rhs;
		return lhsFirst;
	}
};
6542
6543 template <typename T>
inputLess(const T & val1,const T & val2)6544 bool inputLess (const T& val1, const T& val2)
6545 {
6546 return InputLess<T>()(val1, val2);
6547 }
6548
6549 template <>
6550 struct InputLess<float>
6551 {
operator ()vkt::shaderexecutor::InputLess6552 bool operator() (const float& val1, const float& val2) const
6553 {
6554 if (deIsNaN(val1))
6555 return false;
6556 if (deIsNaN(val2))
6557 return true;
6558 return val1 < val2;
6559 }
6560 };
6561
6562 template <typename T, int Size>
6563 struct InputLess<Vector<T, Size> >
6564 {
operator ()vkt::shaderexecutor::InputLess6565 bool operator() (const Vector<T, Size>& vec1, const Vector<T, Size>& vec2) const
6566 {
6567 for (int ndx = 0; ndx < Size; ++ndx)
6568 {
6569 if (inputLess(vec1[ndx], vec2[ndx]))
6570 return true;
6571 if (inputLess(vec2[ndx], vec1[ndx]))
6572 return false;
6573 }
6574
6575 return false;
6576 }
6577 };
6578
6579 template <typename T, int Rows, int Cols>
6580 struct InputLess<Matrix<T, Rows, Cols> >
6581 {
operator ()vkt::shaderexecutor::InputLess6582 bool operator() (const Matrix<T, Rows, Cols>& mat1,
6583 const Matrix<T, Rows, Cols>& mat2) const
6584 {
6585 for (int col = 0; col < Cols; ++col)
6586 {
6587 if (inputLess(mat1[col], mat2[col]))
6588 return true;
6589 if (inputLess(mat2[col], mat1[col]))
6590 return false;
6591 }
6592
6593 return false;
6594 }
6595 };
6596
6597 template <typename In>
6598 struct InTuple :
6599 public Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6600 {
InTuplevkt::shaderexecutor::InTuple6601 InTuple (const typename In::In0& in0,
6602 const typename In::In1& in1,
6603 const typename In::In2& in2,
6604 const typename In::In3& in3)
6605 : Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6606 (in0, in1, in2, in3) {}
6607 };
6608
6609 template <typename In>
6610 struct InputLess<InTuple<In> >
6611 {
operator ()vkt::shaderexecutor::InputLess6612 bool operator() (const InTuple<In>& in1, const InTuple<In>& in2) const
6613 {
6614 if (inputLess(in1.a, in2.a))
6615 return true;
6616 if (inputLess(in2.a, in1.a))
6617 return false;
6618 if (inputLess(in1.b, in2.b))
6619 return true;
6620 if (inputLess(in2.b, in1.b))
6621 return false;
6622 if (inputLess(in1.c, in2.c))
6623 return true;
6624 if (inputLess(in2.c, in1.c))
6625 return false;
6626 if (inputLess(in1.d, in2.d))
6627 return true;
6628 return false;
6629 }
6630 };
6631
6632 template<typename In>
generateInputs(const Samplings<In> & samplings,const FloatFormat & floatFormat,Precision intPrecision,size_t numSamples,deUint32 seed,const Interval & inputRange)6633 Inputs<In> generateInputs (const Samplings<In>& samplings,
6634 const FloatFormat& floatFormat,
6635 Precision intPrecision,
6636 size_t numSamples,
6637 deUint32 seed,
6638 const Interval& inputRange)
6639 {
6640 Random rnd(seed);
6641 Inputs<In> ret;
6642 Inputs<In> fixedInputs;
6643 set<InTuple<In>, InputLess<InTuple<In> > > seenInputs;
6644
6645 samplings.in0.genFixeds(floatFormat, intPrecision, fixedInputs.in0, inputRange);
6646 samplings.in1.genFixeds(floatFormat, intPrecision, fixedInputs.in1, inputRange);
6647 samplings.in2.genFixeds(floatFormat, intPrecision, fixedInputs.in2, inputRange);
6648 samplings.in3.genFixeds(floatFormat, intPrecision, fixedInputs.in3, inputRange);
6649
6650 for (size_t ndx0 = 0; ndx0 < fixedInputs.in0.size(); ++ndx0)
6651 {
6652 for (size_t ndx1 = 0; ndx1 < fixedInputs.in1.size(); ++ndx1)
6653 {
6654 for (size_t ndx2 = 0; ndx2 < fixedInputs.in2.size(); ++ndx2)
6655 {
6656 for (size_t ndx3 = 0; ndx3 < fixedInputs.in3.size(); ++ndx3)
6657 {
6658 const InTuple<In> tuple (fixedInputs.in0[ndx0],
6659 fixedInputs.in1[ndx1],
6660 fixedInputs.in2[ndx2],
6661 fixedInputs.in3[ndx3]);
6662
6663 seenInputs.insert(tuple);
6664 ret.in0.push_back(tuple.a);
6665 ret.in1.push_back(tuple.b);
6666 ret.in2.push_back(tuple.c);
6667 ret.in3.push_back(tuple.d);
6668 }
6669 }
6670 }
6671 }
6672
6673 for (size_t ndx = 0; ndx < numSamples; ++ndx)
6674 {
6675 const typename In::In0 in0 = samplings.in0.genRandom(floatFormat, intPrecision, rnd, inputRange);
6676 const typename In::In1 in1 = samplings.in1.genRandom(floatFormat, intPrecision, rnd, inputRange);
6677 const typename In::In2 in2 = samplings.in2.genRandom(floatFormat, intPrecision, rnd, inputRange);
6678 const typename In::In3 in3 = samplings.in3.genRandom(floatFormat, intPrecision, rnd, inputRange);
6679 const InTuple<In> tuple (in0, in1, in2, in3);
6680
6681 if (de::contains(seenInputs, tuple))
6682 continue;
6683
6684 seenInputs.insert(tuple);
6685 ret.in0.push_back(in0);
6686 ret.in1.push_back(in1);
6687 ret.in2.push_back(in2);
6688 ret.in3.push_back(in3);
6689 }
6690
6691 return ret;
6692 }
6693
6694 class FuncCaseBase : public PrecisionCase
6695 {
6696 protected:
FuncCaseBase(const CaseContext & context,const string & name,const FuncBase & func)6697 FuncCaseBase (const CaseContext& context, const string& name, const FuncBase& func)
6698 : PrecisionCase (context, name, func.getInputRange(!context.isFloat64b && (context.precision == glu::PRECISION_LAST || context.isPackFloat16b)), func.getRequiredExtension())
6699 {
6700 }
6701
6702 StatementP m_stmt;
6703 };
6704
6705 template <typename Sig>
6706 class FuncCase : public FuncCaseBase
6707 {
6708 public:
6709 typedef Func<Sig> CaseFunc;
6710 typedef typename Sig::Ret Ret;
6711 typedef typename Sig::Arg0 Arg0;
6712 typedef typename Sig::Arg1 Arg1;
6713 typedef typename Sig::Arg2 Arg2;
6714 typedef typename Sig::Arg3 Arg3;
6715 typedef InTypes<Arg0, Arg1, Arg2, Arg3> In;
6716 typedef OutTypes<Ret> Out;
6717
FuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)6718 FuncCase (const CaseContext& context, const string& name, const CaseFunc& func, bool modularOp = false)
6719 : FuncCaseBase (context, name, func)
6720 , m_func (func)
6721 , m_modularOp (modularOp)
6722 {
6723 buildTest();
6724 }
6725
createInstance(Context & context) const6726 virtual TestInstance* createInstance (Context& context) const
6727 {
6728 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(), m_stmt, m_modularOp);
6729 }
6730
6731 protected:
6732 void buildTest (void);
getSamplings(void) const6733 virtual const Samplings<In>& getSamplings (void) const
6734 {
6735 return instance<DefaultSamplings<In> >();
6736 }
6737
6738 private:
6739 const CaseFunc& m_func;
6740 Variables<In, Out> m_variables;
6741 bool m_modularOp;
6742 };
6743
6744 template <typename Sig>
buildTest(void)6745 void FuncCase<Sig>::buildTest (void)
6746 {
6747 m_variables.out0 = variable<Ret>("out0");
6748 m_variables.out1 = variable<Void>("out1");
6749 m_variables.in0 = variable<Arg0>("in0");
6750 m_variables.in1 = variable<Arg1>("in1");
6751 m_variables.in2 = variable<Arg2>("in2");
6752 m_variables.in3 = variable<Arg3>("in3");
6753
6754 {
6755 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.in1, m_variables.in2, m_variables.in3);
6756 m_stmt = variableAssignment(m_variables.out0, expr);
6757
6758 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
6759 }
6760 }
6761
6762 template <typename Sig>
6763 class InOutFuncCase : public FuncCaseBase
6764 {
6765 public:
6766 typedef Func<Sig> CaseFunc;
6767 typedef typename Sig::Ret Ret;
6768 typedef typename Sig::Arg0 Arg0;
6769 typedef typename Sig::Arg1 Arg1;
6770 typedef typename Sig::Arg2 Arg2;
6771 typedef typename Sig::Arg3 Arg3;
6772 typedef InTypes<Arg0, Arg2, Arg3> In;
6773 typedef OutTypes<Ret, Arg1> Out;
6774
InOutFuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)6775 InOutFuncCase (const CaseContext& context, const string& name, const CaseFunc& func, bool modularOp = false)
6776 : FuncCaseBase (context, name, func)
6777 , m_func (func)
6778 , m_modularOp (modularOp)
6779 {
6780 buildTest();
6781 }
createInstance(Context & context) const6782 virtual TestInstance* createInstance (Context& context) const
6783 {
6784 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(), m_stmt, m_modularOp);
6785 }
6786
6787 protected:
6788 void buildTest (void);
getSamplings(void) const6789 virtual const Samplings<In>& getSamplings (void) const
6790 {
6791 return instance<DefaultSamplings<In> >();
6792 }
6793
6794 private:
6795 const CaseFunc& m_func;
6796 Variables<In, Out> m_variables;
6797 bool m_modularOp;
6798 };
6799
6800 template <typename Sig>
buildTest(void)6801 void InOutFuncCase<Sig>::buildTest (void)
6802 {
6803 m_variables.out0 = variable<Ret>("out0");
6804 m_variables.out1 = variable<Arg1>("out1");
6805 m_variables.in0 = variable<Arg0>("in0");
6806 m_variables.in1 = variable<Arg2>("in1");
6807 m_variables.in2 = variable<Arg3>("in2");
6808 m_variables.in3 = variable<Void>("in3");
6809
6810 {
6811 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.out1, m_variables.in1, m_variables.in2);
6812 m_stmt = variableAssignment(m_variables.out0, expr);
6813
6814 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
6815 }
6816 }
6817
6818 template <typename Sig>
createFuncCase(const CaseContext & context,const string & name,const Func<Sig> & func,bool modularOp=false)6819 PrecisionCase* createFuncCase (const CaseContext& context, const string& name, const Func<Sig>& func, bool modularOp = false)
6820 {
6821 switch (func.getOutParamIndex())
6822 {
6823 case -1:
6824 return new FuncCase<Sig>(context, name, func, modularOp);
6825 case 1:
6826 return new InOutFuncCase<Sig>(context, name, func, modularOp);
6827 default:
6828 DE_FATAL("Impossible");
6829 }
6830 return DE_NULL;
6831 }
6832
6833 class CaseFactory
6834 {
6835 public:
~CaseFactory(void)6836 virtual ~CaseFactory (void) {}
6837 virtual MovePtr<TestNode> createCase (const CaseContext& ctx) const = 0;
6838 virtual string getName (void) const = 0;
6839 virtual string getDesc (void) const = 0;
6840 };
6841
6842 class FuncCaseFactory : public CaseFactory
6843 {
6844 public:
6845 virtual const FuncBase& getFunc (void) const = 0;
getName(void) const6846 string getName (void) const { return de::toLower(getFunc().getName()); }
getDesc(void) const6847 string getDesc (void) const { return "Function '" + getFunc().getName() + "'"; }
6848 };
6849
6850 template <typename Sig>
6851 class GenFuncCaseFactory : public CaseFactory
6852 {
6853 public:
GenFuncCaseFactory(const GenFuncs<Sig> & funcs,const string & name,bool modularOp=false)6854 GenFuncCaseFactory (const GenFuncs<Sig>& funcs, const string& name, bool modularOp = false)
6855 : m_funcs (funcs)
6856 , m_name (de::toLower(name))
6857 , m_modularOp (modularOp)
6858 {
6859 }
6860
createCase(const CaseContext & ctx) const6861 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6862 {
6863 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
6864
6865 group->addChild(createFuncCase(ctx, "scalar", m_funcs.func, m_modularOp));
6866 group->addChild(createFuncCase(ctx, "vec2", m_funcs.func2, m_modularOp));
6867 group->addChild(createFuncCase(ctx, "vec3", m_funcs.func3, m_modularOp));
6868 group->addChild(createFuncCase(ctx, "vec4", m_funcs.func4, m_modularOp));
6869 return MovePtr<TestNode>(group);
6870 }
6871
getName(void) const6872 string getName (void) const { return m_name; }
getDesc(void) const6873 string getDesc (void) const { return "Function '" + m_funcs.func.getName() + "'"; }
6874
6875 private:
6876 const GenFuncs<Sig> m_funcs;
6877 string m_name;
6878 bool m_modularOp;
6879 };
6880
6881 template <template <int, class> class GenF, typename T>
6882 class TemplateFuncCaseFactory : public FuncCaseFactory
6883 {
6884 public:
createCase(const CaseContext & ctx) const6885 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6886 {
6887 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
6888
6889 group->addChild(createFuncCase(ctx, "scalar", instance<GenF<1, T> >()));
6890 group->addChild(createFuncCase(ctx, "vec2", instance<GenF<2, T> >()));
6891 group->addChild(createFuncCase(ctx, "vec3", instance<GenF<3, T> >()));
6892 group->addChild(createFuncCase(ctx, "vec4", instance<GenF<4, T> >()));
6893
6894 return MovePtr<TestNode>(group);
6895 }
6896
getFunc(void) const6897 const FuncBase& getFunc (void) const { return instance<GenF<1, T> >(); }
6898 };
6899
6900 #ifndef CTS_USES_VULKANSC
6901 template <template <int> class GenF>
6902 class SquareMatrixFuncCaseFactory : public FuncCaseFactory
6903 {
6904 public:
createCase(const CaseContext & ctx) const6905 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6906 {
6907 TestCaseGroup* group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
6908
6909 group->addChild(createFuncCase(ctx, "mat2", instance<GenF<2> >()));
6910
6911 // There is no defined precision for mediump/RelaxedPrecision in Vulkan
6912 if (ctx.name != "mediump")
6913 {
6914 static const char dataDir[] = "builtin/precision/square_matrix";
6915 std::string fileName = getFunc().getName() + "_" + ctx.name;
6916 std::vector<std::string> requirements;
6917
6918 if (ctx.name == "compute")
6919 {
6920 if (ctx.isFloat64b)
6921 {
6922 requirements.push_back("Features.shaderFloat64");
6923 fileName += "_fp64";
6924 }
6925 else
6926 {
6927 requirements.push_back("Float16Int8Features.shaderFloat16");
6928 requirements.push_back("VK_KHR_16bit_storage");
6929 requirements.push_back("VK_KHR_storage_buffer_storage_class");
6930 fileName += "_fp16";
6931
6932 if (ctx.isPackFloat16b == true)
6933 {
6934 fileName += "_32bit";
6935 }
6936 else
6937 {
6938 requirements.push_back("Storage16BitFeatures.storageBuffer16BitAccess");
6939 }
6940 }
6941 }
6942
6943 group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat3", "Square matrix 3x3 precision tests", dataDir, fileName + "_mat_3x3.amber", requirements));
6944 group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat4", "Square matrix 4x4 precision tests", dataDir, fileName + "_mat_4x4.amber", requirements));
6945 }
6946
6947 return MovePtr<TestNode>(group);
6948 }
6949
getFunc(void) const6950 const FuncBase& getFunc (void) const { return instance<GenF<2> >(); }
6951 };
6952 #endif // CTS_USES_VULKANSC
6953
6954 template <template <int, int, class> class GenF, typename T>
6955 class MatrixFuncCaseFactory : public FuncCaseFactory
6956 {
6957 public:
createCase(const CaseContext & ctx) const6958 MovePtr<TestNode> createCase (const CaseContext& ctx) const
6959 {
6960 TestCaseGroup* const group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
6961
6962 this->addCase<2, 2>(ctx, group);
6963 this->addCase<3, 2>(ctx, group);
6964 this->addCase<4, 2>(ctx, group);
6965 this->addCase<2, 3>(ctx, group);
6966 this->addCase<3, 3>(ctx, group);
6967 this->addCase<4, 3>(ctx, group);
6968 this->addCase<2, 4>(ctx, group);
6969 this->addCase<3, 4>(ctx, group);
6970 this->addCase<4, 4>(ctx, group);
6971
6972 return MovePtr<TestNode>(group);
6973 }
6974
getFunc(void) const6975 const FuncBase& getFunc (void) const { return instance<GenF<2,2, T> >(); }
6976
6977 private:
6978 template <int Rows, int Cols>
addCase(const CaseContext & ctx,TestCaseGroup * group) const6979 void addCase (const CaseContext& ctx, TestCaseGroup* group) const
6980 {
6981 const char* const name = dataTypeNameOf<Matrix<float, Rows, Cols> >();
6982 group->addChild(createFuncCase(ctx, name, instance<GenF<Rows, Cols, T> >()));
6983 }
6984 };
6985
6986 template <typename Sig>
6987 class SimpleFuncCaseFactory : public CaseFactory
6988 {
6989 public:
SimpleFuncCaseFactory(const Func<Sig> & func)6990 SimpleFuncCaseFactory (const Func<Sig>& func) : m_func(func) {}
6991
createCase(const CaseContext & ctx) const6992 MovePtr<TestNode> createCase (const CaseContext& ctx) const { return MovePtr<TestNode>(createFuncCase(ctx, ctx.name.c_str(), m_func)); }
getName(void) const6993 string getName (void) const { return de::toLower(m_func.getName()); }
getDesc(void) const6994 string getDesc (void) const { return "Function '" + getName() + "'"; }
6995
6996 private:
6997 const Func<Sig>& m_func;
6998 };
6999
7000 template <typename F>
createSimpleFuncCaseFactory(void)7001 SharedPtr<SimpleFuncCaseFactory<typename F::Sig> > createSimpleFuncCaseFactory (void)
7002 {
7003 return SharedPtr<SimpleFuncCaseFactory<typename F::Sig> >(new SimpleFuncCaseFactory<typename F::Sig>(instance<F>()));
7004 }
7005
7006 class CaseFactories
7007 {
7008 public:
~CaseFactories(void)7009 virtual ~CaseFactories (void) {}
7010 virtual const std::vector<const CaseFactory*> getFactories (void) const = 0;
7011 };
7012
7013 class BuiltinFuncs : public CaseFactories
7014 {
7015 public:
getFactories(void) const7016 const vector<const CaseFactory*> getFactories (void) const
7017 {
7018 vector<const CaseFactory*> ret;
7019
7020 for (size_t ndx = 0; ndx < m_factories.size(); ++ndx)
7021 ret.push_back(m_factories[ndx].get());
7022
7023 return ret;
7024 }
7025
addFactory(SharedPtr<const CaseFactory> fact)7026 void addFactory (SharedPtr<const CaseFactory> fact) { m_factories.push_back(fact); }
7027
7028 private:
7029 vector<SharedPtr<const CaseFactory> > m_factories;
7030 };
7031
7032 template <typename F>
addScalarFactory(BuiltinFuncs & funcs,string name="",bool modularOp=false)7033 void addScalarFactory (BuiltinFuncs& funcs, string name = "", bool modularOp = false)
7034 {
7035 if (name.empty())
7036 name = instance<F>().getName();
7037
7038 funcs.addFactory(SharedPtr<const CaseFactory>(new GenFuncCaseFactory<typename F::Sig>(makeVectorizedFuncs<F>(), name, modularOp)));
7039 }
7040
createBuiltinCases()7041 MovePtr<const CaseFactories> createBuiltinCases ()
7042 {
7043 MovePtr<BuiltinFuncs> funcs (new BuiltinFuncs());
7044
7045 // Tests for ES3 builtins
7046 addScalarFactory<Comparison< Signature<int, float, float> > >(*funcs);
7047 addScalarFactory<Add< Signature<float, float, float> > >(*funcs);
7048 addScalarFactory<Sub< Signature<float, float, float> > >(*funcs);
7049 addScalarFactory<Mul< Signature<float, float, float> > >(*funcs);
7050 addScalarFactory<Div< Signature<float, float, float> > >(*funcs);
7051
7052 addScalarFactory<Radians>(*funcs);
7053 addScalarFactory<Degrees>(*funcs);
7054 addScalarFactory<Sin<Signature<float, float> > >(*funcs);
7055 addScalarFactory<Cos<Signature<float, float> > >(*funcs);
7056 addScalarFactory<Tan>(*funcs);
7057
7058 addScalarFactory<ASin>(*funcs);
7059 addScalarFactory<ACos>(*funcs);
7060 addScalarFactory<ATan2< Signature<float, float, float> > >(*funcs, "atan2");
7061 addScalarFactory<ATan<Signature<float, float> > >(*funcs);
7062 addScalarFactory<Sinh>(*funcs);
7063 addScalarFactory<Cosh>(*funcs);
7064 addScalarFactory<Tanh>(*funcs);
7065 addScalarFactory<ASinh>(*funcs);
7066 addScalarFactory<ACosh>(*funcs);
7067 addScalarFactory<ATanh>(*funcs);
7068
7069 addScalarFactory<Pow>(*funcs);
7070 addScalarFactory<Exp<Signature<float, float> > >(*funcs);
7071 addScalarFactory<Log< Signature<float, float> > >(*funcs);
7072 addScalarFactory<Exp2<Signature<float, float> > >(*funcs);
7073 addScalarFactory<Log2< Signature<float, float> > >(*funcs);
7074 addScalarFactory<Sqrt32Bit>(*funcs);
7075 addScalarFactory<InverseSqrt< Signature<float, float> > >(*funcs);
7076
7077 addScalarFactory<Abs< Signature<float, float> > >(*funcs);
7078 addScalarFactory<Sign< Signature<float, float> > >(*funcs);
7079 addScalarFactory<Floor32Bit>(*funcs);
7080 addScalarFactory<Trunc32Bit>(*funcs);
7081 addScalarFactory<Round< Signature<float, float> > >(*funcs);
7082 addScalarFactory<RoundEven< Signature<float, float> > >(*funcs);
7083 addScalarFactory<Ceil< Signature<float, float> > >(*funcs);
7084 addScalarFactory<Fract>(*funcs);
7085
7086 addScalarFactory<Mod32Bit>(*funcs, "mod", true);
7087 addScalarFactory<FRem32Bit>(*funcs);
7088
7089 addScalarFactory<Modf32Bit>(*funcs);
7090 addScalarFactory<ModfStruct32Bit>(*funcs);
7091 addScalarFactory<Min< Signature<float, float, float> > >(*funcs);
7092 addScalarFactory<Max< Signature<float, float, float> > >(*funcs);
7093 addScalarFactory<Clamp< Signature<float, float, float, float> > >(*funcs);
7094 addScalarFactory<Mix>(*funcs);
7095 addScalarFactory<Step< Signature<float, float, float> > >(*funcs);
7096 addScalarFactory<SmoothStep< Signature<float, float, float, float> > >(*funcs);
7097
7098 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, float>()));
7099 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, float>()));
7100 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, float>()));
7101 funcs->addFactory(createSimpleFuncCaseFactory<Cross>());
7102 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, float>()));
7103 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, float>()));
7104 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, float>()));
7105 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, float>()));
7106
7107 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, float>()));
7108 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, float>()));
7109 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, float>()));
7110 #ifndef CTS_USES_VULKANSC
7111 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant>()));
7112 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse>()));
7113 #endif // CTS_USES_VULKANSC
7114
7115 addScalarFactory<Frexp32Bit>(*funcs);
7116 addScalarFactory<FrexpStruct32Bit>(*funcs);
7117 addScalarFactory<LdExp <Signature<float, float, int> > >(*funcs);
7118 addScalarFactory<Fma <Signature<float, float, float, float> > >(*funcs);
7119
7120 return MovePtr<const CaseFactories>(funcs.release());
7121 }
7122
createBuiltinDoubleCases()7123 MovePtr<const CaseFactories> createBuiltinDoubleCases ()
7124 {
7125 MovePtr<BuiltinFuncs> funcs (new BuiltinFuncs());
7126
7127 // Tests for ES3 builtins
7128 addScalarFactory<Comparison<Signature<int, double, double>>>(*funcs);
7129 addScalarFactory<Add<Signature<double, double, double>>>(*funcs);
7130 addScalarFactory<Sub<Signature<double, double, double>>>(*funcs);
7131 addScalarFactory<Mul<Signature<double, double, double>>>(*funcs);
7132 addScalarFactory<Div<Signature<double, double, double>>>(*funcs);
7133
7134 // Radians, degrees, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, pow, exp, log, exp2 and log2
7135 // only work with 16-bit and 32-bit floating point types according to the spec.
7136
7137 addScalarFactory<Sqrt64Bit>(*funcs);
7138 addScalarFactory<InverseSqrt<Signature<double, double>>>(*funcs);
7139
7140 addScalarFactory<Abs<Signature<double, double>>>(*funcs);
7141 addScalarFactory<Sign<Signature<double, double>>>(*funcs);
7142 addScalarFactory<Floor64Bit>(*funcs);
7143 addScalarFactory<Trunc64Bit>(*funcs);
7144 addScalarFactory<Round<Signature<double, double>>>(*funcs);
7145 addScalarFactory<RoundEven<Signature<double, double>>>(*funcs);
7146 addScalarFactory<Ceil<Signature<double, double>>>(*funcs);
7147 addScalarFactory<Fract64Bit>(*funcs);
7148
7149 addScalarFactory<Mod64Bit>(*funcs, "mod", true);
7150 addScalarFactory<FRem64Bit>(*funcs);
7151
7152 addScalarFactory<Modf64Bit>(*funcs);
7153 addScalarFactory<ModfStruct64Bit>(*funcs);
7154 addScalarFactory<Min<Signature<double, double, double>>>(*funcs);
7155 addScalarFactory<Max<Signature<double, double, double>>>(*funcs);
7156 addScalarFactory<Clamp<Signature<double, double, double, double>>>(*funcs);
7157 addScalarFactory<Mix64Bit>(*funcs);
7158 addScalarFactory<Step<Signature<double, double, double>>>(*funcs);
7159 addScalarFactory<SmoothStep<Signature<double, double, double, double>>>(*funcs);
7160
7161 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, double>()));
7162 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, double>()));
7163 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, double>()));
7164 funcs->addFactory(createSimpleFuncCaseFactory<Cross64Bit>());
7165 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, double>()));
7166 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, double>()));
7167 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, double>()));
7168 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, double>()));
7169
7170 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, double>()));
7171 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, double>()));
7172 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, double>()));
7173 #ifndef CTS_USES_VULKANSC
7174 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant64bit>()));
7175 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse64bit>()));
7176 #endif // CTS_USES_VULKANSC
7177
7178 addScalarFactory<Frexp64Bit>(*funcs);
7179 addScalarFactory<FrexpStruct64Bit>(*funcs);
7180 addScalarFactory<LdExp<Signature<double, double, int>>>(*funcs);
7181 addScalarFactory<Fma<Signature<double, double, double, double>>>(*funcs);
7182
7183 return MovePtr<const CaseFactories>(funcs.release());
7184 }
7185
createBuiltinCases16Bit(void)7186 MovePtr<const CaseFactories> createBuiltinCases16Bit(void)
7187 {
7188 MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7189
7190 addScalarFactory<Comparison< Signature<int, deFloat16, deFloat16> > >(*funcs);
7191 addScalarFactory<Add< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7192 addScalarFactory<Sub< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7193 addScalarFactory<Mul< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7194 addScalarFactory<Div< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7195
7196 addScalarFactory<Radians16>(*funcs);
7197 addScalarFactory<Degrees16>(*funcs);
7198
7199 addScalarFactory<Sin<Signature<deFloat16, deFloat16> > >(*funcs);
7200 addScalarFactory<Cos<Signature<deFloat16, deFloat16> > >(*funcs);
7201 addScalarFactory<Tan16Bit>(*funcs);
7202 addScalarFactory<ASin16Bit>(*funcs);
7203 addScalarFactory<ACos16Bit>(*funcs);
7204 addScalarFactory<ATan2< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs, "atan2");
7205 addScalarFactory<ATan<Signature<deFloat16, deFloat16> > >(*funcs);
7206
7207 addScalarFactory<Sinh16Bit>(*funcs);
7208 addScalarFactory<Cosh16Bit>(*funcs);
7209 addScalarFactory<Tanh16Bit>(*funcs);
7210 addScalarFactory<ASinh16Bit>(*funcs);
7211 addScalarFactory<ACosh16Bit>(*funcs);
7212 addScalarFactory<ATanh16Bit>(*funcs);
7213
7214 addScalarFactory<Pow16>(*funcs);
7215 addScalarFactory<Exp< Signature<deFloat16, deFloat16> > >(*funcs);
7216 addScalarFactory<Log< Signature<deFloat16, deFloat16> > >(*funcs);
7217 addScalarFactory<Exp2< Signature<deFloat16, deFloat16> > >(*funcs);
7218 addScalarFactory<Log2< Signature<deFloat16, deFloat16> > >(*funcs);
7219 addScalarFactory<Sqrt16Bit>(*funcs);
7220 addScalarFactory<InverseSqrt16Bit>(*funcs);
7221
7222 addScalarFactory<Abs< Signature<deFloat16, deFloat16> > >(*funcs);
7223 addScalarFactory<Sign< Signature<deFloat16, deFloat16> > >(*funcs);
7224 addScalarFactory<Floor16Bit>(*funcs);
7225 addScalarFactory<Trunc16Bit>(*funcs);
7226 addScalarFactory<Round< Signature<deFloat16, deFloat16> > >(*funcs);
7227 addScalarFactory<RoundEven< Signature<deFloat16, deFloat16> > >(*funcs);
7228 addScalarFactory<Ceil< Signature<deFloat16, deFloat16> > >(*funcs);
7229 addScalarFactory<Fract16Bit>(*funcs);
7230
7231 addScalarFactory<Mod16Bit>(*funcs, "mod", true);
7232 addScalarFactory<FRem16Bit>(*funcs);
7233
7234 addScalarFactory<Modf16Bit>(*funcs);
7235 addScalarFactory<ModfStruct16Bit>(*funcs);
7236 addScalarFactory<Min< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7237 addScalarFactory<Max< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7238 addScalarFactory<Clamp< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7239 addScalarFactory<Mix16Bit>(*funcs);
7240 addScalarFactory<Step< Signature<deFloat16, deFloat16, deFloat16> > >(*funcs);
7241 addScalarFactory<SmoothStep< Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7242
7243 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, deFloat16>()));
7244 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, deFloat16>()));
7245 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, deFloat16>()));
7246 funcs->addFactory(createSimpleFuncCaseFactory<Cross16Bit>());
7247 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, deFloat16>()));
7248 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, deFloat16>()));
7249 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, deFloat16>()));
7250 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, deFloat16>()));
7251
7252 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, deFloat16>()));
7253 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, deFloat16>()));
7254 #ifndef CTS_USES_VULKANSC
7255 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant16bit>()));
7256 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse16bit>()));
7257 #endif // CTS_USES_VULKANSC
7258
7259 addScalarFactory<Frexp16Bit>(*funcs);
7260 addScalarFactory<FrexpStruct16Bit>(*funcs);
7261 addScalarFactory<LdExp <Signature<deFloat16, deFloat16, int> > >(*funcs);
7262 addScalarFactory<Fma <Signature<deFloat16, deFloat16, deFloat16, deFloat16> > >(*funcs);
7263
7264 return MovePtr<const CaseFactories>(funcs.release());
7265 }
7266
createFuncGroup(TestContext & ctx,const CaseFactory & factory,int numRandoms)7267 TestCaseGroup* createFuncGroup (TestContext& ctx, const CaseFactory& factory, int numRandoms)
7268 {
7269 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str());
7270 const FloatFormat highp (-126, 127, 23, true,
7271 tcu::MAYBE, // subnormals
7272 tcu::YES, // infinities
7273 tcu::MAYBE); // NaN
7274 const FloatFormat mediump (-14, 13, 10, false, tcu::MAYBE);
7275
7276 for (int precNdx = glu::PRECISION_MEDIUMP; precNdx < glu::PRECISION_LAST; ++precNdx)
7277 {
7278 const Precision precision = Precision(precNdx);
7279 const string precName (glu::getPrecisionName(precision));
7280 const FloatFormat& fmt = precNdx == glu::PRECISION_MEDIUMP ? mediump : highp;
7281
7282 const CaseContext caseCtx (precName, ctx, fmt, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms);
7283
7284 group->addChild(factory.createCase(caseCtx).release());
7285 }
7286
7287 return group;
7288 }
7289
createFuncGroupDouble(TestContext & ctx,const CaseFactory & factory,int numRandoms)7290 TestCaseGroup* createFuncGroupDouble (TestContext& ctx, const CaseFactory& factory, int numRandoms)
7291 {
7292 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str());
7293 const Precision precision = Precision(glu::PRECISION_LAST);
7294 const FloatFormat highp (-1022, 1023, 52, true,
7295 tcu::MAYBE, // subnormals
7296 tcu::YES, // infinities
7297 tcu::MAYBE); // NaN
7298
7299 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT;
7300
7301 const CaseContext caseCtx("compute", ctx, highp, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms, precisionTestFeatures, false, true);
7302 group->addChild(factory.createCase(caseCtx).release());
7303
7304 return group;
7305 }
7306
createFuncGroup16Bit(TestContext & ctx,const CaseFactory & factory,int numRandoms,bool storage32)7307 TestCaseGroup* createFuncGroup16Bit(TestContext& ctx, const CaseFactory& factory, int numRandoms, bool storage32)
7308 {
7309 TestCaseGroup* const group = new TestCaseGroup(ctx, factory.getName().c_str());
7310 const Precision precision = Precision(glu::PRECISION_LAST);
7311 const FloatFormat float16 (-14, 15, 10, true, tcu::MAYBE);
7312
7313 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT;
7314 if (!storage32)
7315 precisionTestFeatures |= PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS;
7316
7317 const CaseContext caseCtx("compute", ctx, float16, float16, precision, glu::SHADERTYPE_COMPUTE, numRandoms, precisionTestFeatures, storage32);
7318 group->addChild(factory.createCase(caseCtx).release());
7319
7320 return group;
7321 }
7322
// Default number of random inputs per case. It is used whenever the command
// line does not supply a positive test iteration count.
constexpr int defRandoms = 16384;
7324
addBuiltinPrecisionTests(TestContext & ctx,TestCaseGroup & dstGroup,const bool test16Bit=false,const bool storage32Bit=false)7325 void addBuiltinPrecisionTests (TestContext& ctx,
7326 TestCaseGroup& dstGroup,
7327 const bool test16Bit = false,
7328 const bool storage32Bit = false)
7329 {
7330 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7331 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7332
7333 MovePtr<const CaseFactories> cases = (test16Bit && !storage32Bit) ? createBuiltinCases16Bit()
7334 : createBuiltinCases();
7335 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7336 {
7337 if (!test16Bit)
7338 dstGroup.addChild(createFuncGroup(ctx, *cases->getFactories()[ndx], numRandoms));
7339 else
7340 dstGroup.addChild(createFuncGroup16Bit(ctx, *cases->getFactories()[ndx], numRandoms, storage32Bit));
7341 }
7342 }
7343
addBuiltinPrecisionDoubleTests(TestContext & ctx,TestCaseGroup & dstGroup)7344 void addBuiltinPrecisionDoubleTests (TestContext& ctx,
7345 TestCaseGroup& dstGroup)
7346 {
7347 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7348 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7349
7350 MovePtr<const CaseFactories> cases = createBuiltinDoubleCases();
7351 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7352 {
7353 dstGroup.addChild(createFuncGroupDouble(ctx, *cases->getFactories()[ndx], numRandoms));
7354 }
7355 }
7356
BuiltinPrecisionTests(tcu::TestContext & testCtx)7357 BuiltinPrecisionTests::BuiltinPrecisionTests (tcu::TestContext& testCtx)
7358 : tcu::TestCaseGroup(testCtx, "precision")
7359 {
7360 }
7361
~BuiltinPrecisionTests(void)7362 BuiltinPrecisionTests::~BuiltinPrecisionTests (void)
7363 {
7364 }
7365
init(void)7366 void BuiltinPrecisionTests::init (void)
7367 {
7368 addBuiltinPrecisionTests(m_testCtx, *this);
7369 }
7370
BuiltinPrecisionDoubleTests(tcu::TestContext & testCtx)7371 BuiltinPrecisionDoubleTests::BuiltinPrecisionDoubleTests (tcu::TestContext& testCtx)
7372 : tcu::TestCaseGroup(testCtx, "precision_double")
7373 {
7374 }
7375
~BuiltinPrecisionDoubleTests(void)7376 BuiltinPrecisionDoubleTests::~BuiltinPrecisionDoubleTests (void)
7377 {
7378 }
7379
init(void)7380 void BuiltinPrecisionDoubleTests::init (void)
7381 {
7382 addBuiltinPrecisionDoubleTests(m_testCtx, *this);
7383 }
7384
BuiltinPrecision16BitTests(tcu::TestContext & testCtx)7385 BuiltinPrecision16BitTests::BuiltinPrecision16BitTests (tcu::TestContext& testCtx)
7386 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage16b")
7387 {
7388 }
7389
~BuiltinPrecision16BitTests(void)7390 BuiltinPrecision16BitTests::~BuiltinPrecision16BitTests (void)
7391 {
7392 }
7393
init(void)7394 void BuiltinPrecision16BitTests::init (void)
7395 {
7396 addBuiltinPrecisionTests(m_testCtx, *this, true);
7397 }
7398
BuiltinPrecision16Storage32BitTests(tcu::TestContext & testCtx)7399 BuiltinPrecision16Storage32BitTests::BuiltinPrecision16Storage32BitTests(tcu::TestContext& testCtx)
7400 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage32b")
7401 {
7402 }
7403
~BuiltinPrecision16Storage32BitTests(void)7404 BuiltinPrecision16Storage32BitTests::~BuiltinPrecision16Storage32BitTests(void)
7405 {
7406 }
7407
init(void)7408 void BuiltinPrecision16Storage32BitTests::init(void)
7409 {
7410 addBuiltinPrecisionTests(m_testCtx, *this, true, true);
7411 }
7412
7413 } // shaderexecutor
7414 } // vkt
7415