1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates correctness of derivative SPIR-V instructions.
16
17 #include "source/val/validate.h"
18
19 #include <string>
20
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validation_state.h"
25
26 namespace spvtools {
27 namespace val {
28
29 // Validates correctness of derivative instructions.
DerivativesPass(ValidationState_t & _,const Instruction * inst)30 spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
31 const SpvOp opcode = inst->opcode();
32 const uint32_t result_type = inst->type_id();
33
34 switch (opcode) {
35 case SpvOpDPdx:
36 case SpvOpDPdy:
37 case SpvOpFwidth:
38 case SpvOpDPdxFine:
39 case SpvOpDPdyFine:
40 case SpvOpFwidthFine:
41 case SpvOpDPdxCoarse:
42 case SpvOpDPdyCoarse:
43 case SpvOpFwidthCoarse: {
44 if (!_.IsFloatScalarOrVectorType(result_type)) {
45 return _.diag(SPV_ERROR_INVALID_DATA, inst)
46 << "Expected Result Type to be float scalar or vector type: "
47 << spvOpcodeString(opcode);
48 }
49 if (!_.ContainsSizedIntOrFloatType(result_type, SpvOpTypeFloat, 32)) {
50 return _.diag(SPV_ERROR_INVALID_DATA, inst)
51 << "Result type component width must be 32 bits";
52 }
53
54 const uint32_t p_type = _.GetOperandTypeId(inst, 2);
55 if (p_type != result_type) {
56 return _.diag(SPV_ERROR_INVALID_DATA, inst)
57 << "Expected P type and Result Type to be the same: "
58 << spvOpcodeString(opcode);
59 }
60 _.function(inst->function()->id())
61 ->RegisterExecutionModelLimitation([opcode](SpvExecutionModel model,
62 std::string* message) {
63 if (model != SpvExecutionModelFragment &&
64 model != SpvExecutionModelGLCompute) {
65 if (message) {
66 *message =
67 std::string(
68 "Derivative instructions require Fragment or GLCompute "
69 "execution model: ") +
70 spvOpcodeString(opcode);
71 }
72 return false;
73 }
74 return true;
75 });
76 _.function(inst->function()->id())
77 ->RegisterLimitation([opcode](const ValidationState_t& state,
78 const Function* entry_point,
79 std::string* message) {
80 const auto* models = state.GetExecutionModels(entry_point->id());
81 const auto* modes = state.GetExecutionModes(entry_point->id());
82 if (models->find(SpvExecutionModelGLCompute) != models->end() &&
83 modes->find(SpvExecutionModeDerivativeGroupLinearNV) ==
84 modes->end() &&
85 modes->find(SpvExecutionModeDerivativeGroupQuadsNV) ==
86 modes->end()) {
87 if (message) {
88 *message = std::string(
89 "Derivative instructions require "
90 "DerivativeGroupQuadsNV "
91 "or DerivativeGroupLinearNV execution mode for "
92 "GLCompute execution model: ") +
93 spvOpcodeString(opcode);
94 }
95 return false;
96 }
97 return true;
98 });
99 break;
100 }
101
102 default:
103 break;
104 }
105
106 return SPV_SUCCESS;
107 }
108
109 } // namespace val
110 } // namespace spvtools
111