1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates correctness of derivative SPIR-V instructions.
16
17 #include <string>
18
19 #include "source/opcode.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validate.h"
22 #include "source/val/validation_state.h"
23
24 namespace spvtools {
25 namespace val {
26
27 // Validates correctness of derivative instructions.
DerivativesPass(ValidationState_t & _,const Instruction * inst)28 spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
29 const spv::Op opcode = inst->opcode();
30 const uint32_t result_type = inst->type_id();
31
32 switch (opcode) {
33 case spv::Op::OpDPdx:
34 case spv::Op::OpDPdy:
35 case spv::Op::OpFwidth:
36 case spv::Op::OpDPdxFine:
37 case spv::Op::OpDPdyFine:
38 case spv::Op::OpFwidthFine:
39 case spv::Op::OpDPdxCoarse:
40 case spv::Op::OpDPdyCoarse:
41 case spv::Op::OpFwidthCoarse: {
42 if (!_.IsFloatScalarOrVectorType(result_type)) {
43 return _.diag(SPV_ERROR_INVALID_DATA, inst)
44 << "Expected Result Type to be float scalar or vector type: "
45 << spvOpcodeString(opcode);
46 }
47 if (!_.ContainsSizedIntOrFloatType(result_type, spv::Op::OpTypeFloat,
48 32)) {
49 return _.diag(SPV_ERROR_INVALID_DATA, inst)
50 << "Result type component width must be 32 bits";
51 }
52
53 const uint32_t p_type = _.GetOperandTypeId(inst, 2);
54 if (p_type != result_type) {
55 return _.diag(SPV_ERROR_INVALID_DATA, inst)
56 << "Expected P type and Result Type to be the same: "
57 << spvOpcodeString(opcode);
58 }
59 _.function(inst->function()->id())
60 ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
61 std::string* message) {
62 if (model != spv::ExecutionModel::Fragment &&
63 model != spv::ExecutionModel::GLCompute) {
64 if (message) {
65 *message =
66 std::string(
67 "Derivative instructions require Fragment or GLCompute "
68 "execution model: ") +
69 spvOpcodeString(opcode);
70 }
71 return false;
72 }
73 return true;
74 });
75 _.function(inst->function()->id())
76 ->RegisterLimitation([opcode](const ValidationState_t& state,
77 const Function* entry_point,
78 std::string* message) {
79 const auto* models = state.GetExecutionModels(entry_point->id());
80 const auto* modes = state.GetExecutionModes(entry_point->id());
81 if (models &&
82 models->find(spv::ExecutionModel::GLCompute) != models->end() &&
83 (!modes ||
84 (modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
85 modes->end() &&
86 modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
87 modes->end()))) {
88 if (message) {
89 *message = std::string(
90 "Derivative instructions require "
91 "DerivativeGroupQuadsNV "
92 "or DerivativeGroupLinearNV execution mode for "
93 "GLCompute execution model: ") +
94 spvOpcodeString(opcode);
95 }
96 return false;
97 }
98 return true;
99 });
100 break;
101 }
102
103 default:
104 break;
105 }
106
107 return SPV_SUCCESS;
108 }
109
110 } // namespace val
111 } // namespace spvtools
112