1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates correctness of derivative SPIR-V instructions.
16
17 #include "source/val/validate.h"
18
19 #include <string>
20
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validation_state.h"
25
26 namespace spvtools {
27 namespace val {
28
29 // Validates correctness of derivative instructions.
DerivativesPass(ValidationState_t & _,const Instruction * inst)30 spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
31 const spv::Op opcode = inst->opcode();
32 const uint32_t result_type = inst->type_id();
33
34 switch (opcode) {
35 case spv::Op::OpDPdx:
36 case spv::Op::OpDPdy:
37 case spv::Op::OpFwidth:
38 case spv::Op::OpDPdxFine:
39 case spv::Op::OpDPdyFine:
40 case spv::Op::OpFwidthFine:
41 case spv::Op::OpDPdxCoarse:
42 case spv::Op::OpDPdyCoarse:
43 case spv::Op::OpFwidthCoarse: {
44 if (!_.IsFloatScalarOrVectorType(result_type)) {
45 return _.diag(SPV_ERROR_INVALID_DATA, inst)
46 << "Expected Result Type to be float scalar or vector type: "
47 << spvOpcodeString(opcode);
48 }
49 if (!_.ContainsSizedIntOrFloatType(result_type, spv::Op::OpTypeFloat,
50 32)) {
51 return _.diag(SPV_ERROR_INVALID_DATA, inst)
52 << "Result type component width must be 32 bits";
53 }
54
55 const uint32_t p_type = _.GetOperandTypeId(inst, 2);
56 if (p_type != result_type) {
57 return _.diag(SPV_ERROR_INVALID_DATA, inst)
58 << "Expected P type and Result Type to be the same: "
59 << spvOpcodeString(opcode);
60 }
61 _.function(inst->function()->id())
62 ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
63 std::string* message) {
64 if (model != spv::ExecutionModel::Fragment &&
65 model != spv::ExecutionModel::GLCompute) {
66 if (message) {
67 *message =
68 std::string(
69 "Derivative instructions require Fragment or GLCompute "
70 "execution model: ") +
71 spvOpcodeString(opcode);
72 }
73 return false;
74 }
75 return true;
76 });
77 _.function(inst->function()->id())
78 ->RegisterLimitation([opcode](const ValidationState_t& state,
79 const Function* entry_point,
80 std::string* message) {
81 const auto* models = state.GetExecutionModels(entry_point->id());
82 const auto* modes = state.GetExecutionModes(entry_point->id());
83 if (models &&
84 models->find(spv::ExecutionModel::GLCompute) != models->end() &&
85 (!modes ||
86 (modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
87 modes->end() &&
88 modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
89 modes->end()))) {
90 if (message) {
91 *message = std::string(
92 "Derivative instructions require "
93 "DerivativeGroupQuadsNV "
94 "or DerivativeGroupLinearNV execution mode for "
95 "GLCompute execution model: ") +
96 spvOpcodeString(opcode);
97 }
98 return false;
99 }
100 return true;
101 });
102 break;
103 }
104
105 default:
106 break;
107 }
108
109 return SPV_SUCCESS;
110 }
111
112 } // namespace val
113 } // namespace spvtools
114