1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "spirv_stats.h"
16
17 #include <cassert>
18
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <vector>
23
24 #include "binary.h"
25 #include "diagnostic.h"
26 #include "enum_string_mapping.h"
27 #include "extensions.h"
28 #include "instruction.h"
29 #include "opcode.h"
30 #include "operand.h"
31 #include "spirv-tools/libspirv.h"
32 #include "spirv_endian.h"
33 #include "spirv_validator_options.h"
34 #include "validate.h"
35 #include "val/instruction.h"
36 #include "val/validation_state.h"
37
38 using libspirv::Instruction;
39 using libspirv::SpirvStats;
40 using libspirv::ValidationState_t;
41
42 namespace {
43
44 // Helper class for stats aggregation. Receives as in/out parameter.
45 // Constructs ValidationState and updates it by running validator for each
46 // instruction.
47 class StatsAggregator {
48 public:
StatsAggregator(SpirvStats * in_out_stats,const spv_const_context context)49 StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context) {
50 stats_ = in_out_stats;
51 vstate_.reset(new ValidationState_t(context, &validator_options_));
52 }
53
54 // Collects header statistics and sets correct id_bound.
ProcessHeader(spv_endianness_t,uint32_t,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t)55 spv_result_t ProcessHeader(
56 spv_endianness_t /* endian */, uint32_t /* magic */,
57 uint32_t version, uint32_t generator, uint32_t id_bound,
58 uint32_t /* schema */) {
59 vstate_->setIdBound(id_bound);
60 ++stats_->version_hist[version];
61 ++stats_->generator_hist[generator];
62 return SPV_SUCCESS;
63 }
64
65 // Runs validator to validate the instruction and update vstate_,
66 // then procession the instruction to collect stats.
ProcessInstruction(const spv_parsed_instruction_t * inst)67 spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) {
68 const spv_result_t validation_result =
69 spvtools::ValidateInstructionAndUpdateValidationState(vstate_.get(), inst);
70 if (validation_result != SPV_SUCCESS)
71 return validation_result;
72
73 ProcessOpcode();
74 ProcessCapability();
75 ProcessExtension();
76 ProcessConstant();
77
78 return SPV_SUCCESS;
79 }
80
81 // Collects OpCapability statistics.
ProcessCapability()82 void ProcessCapability() {
83 const Instruction& inst = GetCurrentInstruction();
84 if (inst.opcode() != SpvOpCapability) return;
85 const uint32_t capability = inst.word(inst.operands()[0].offset);
86 ++stats_->capability_hist[capability];
87 }
88
89 // Collects OpExtension statistics.
ProcessExtension()90 void ProcessExtension() {
91 const Instruction& inst = GetCurrentInstruction();
92 if (inst.opcode() != SpvOpExtension) return;
93 const std::string extension = libspirv::GetExtensionString(&inst.c_inst());
94 ++stats_->extension_hist[extension];
95 }
96
97 // Collects OpCode statistics.
ProcessOpcode()98 void ProcessOpcode() {
99 auto inst_it = vstate_->ordered_instructions().rbegin();
100 const SpvOp opcode = inst_it->opcode();
101 ++stats_->opcode_hist[opcode];
102
103 ++inst_it;
104 auto step_it = stats_->opcode_markov_hist.begin();
105 for (; inst_it != vstate_->ordered_instructions().rend() &&
106 step_it != stats_->opcode_markov_hist.end(); ++inst_it, ++step_it) {
107 auto& hist = (*step_it)[inst_it->opcode()];
108 ++hist[opcode];
109 }
110 }
111
112 // Collects OpConstant statistics.
ProcessConstant()113 void ProcessConstant() {
114 const Instruction& inst = GetCurrentInstruction();
115 if (inst.opcode() != SpvOpConstant) return;
116 const uint32_t type_id = inst.GetOperandAs<uint32_t>(0);
117 const auto type_decl_it = vstate_->all_definitions().find(type_id);
118 assert(type_decl_it != vstate_->all_definitions().end());
119 const Instruction& type_decl_inst = *type_decl_it->second;
120 const SpvOp type_op = type_decl_inst.opcode();
121 if (type_op == SpvOpTypeInt) {
122 const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
123 const uint32_t is_signed = type_decl_inst.GetOperandAs<uint32_t>(2);
124 assert(is_signed == 0 || is_signed == 1);
125 if (bit_width == 16) {
126 if (is_signed)
127 ++stats_->s16_constant_hist[inst.GetOperandAs<int16_t>(2)];
128 else
129 ++stats_->u16_constant_hist[inst.GetOperandAs<uint16_t>(2)];
130 } else if (bit_width == 32) {
131 if (is_signed)
132 ++stats_->s32_constant_hist[inst.GetOperandAs<int32_t>(2)];
133 else
134 ++stats_->u32_constant_hist[inst.GetOperandAs<uint32_t>(2)];
135 } else if (bit_width == 64) {
136 if (is_signed)
137 ++stats_->s64_constant_hist[inst.GetOperandAs<int64_t>(2)];
138 else
139 ++stats_->u64_constant_hist[inst.GetOperandAs<uint64_t>(2)];
140 } else {
141 assert(false && "TypeInt bit width is not 16, 32 or 64");
142 }
143 } else if (type_op == SpvOpTypeFloat) {
144 const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
145 if (bit_width == 32) {
146 ++stats_->f32_constant_hist[inst.GetOperandAs<float>(2)];
147 } else if (bit_width == 64) {
148 ++stats_->f64_constant_hist[inst.GetOperandAs<double>(2)];
149 } else {
150 assert(bit_width == 16);
151 }
152 }
153 }
154
stats()155 SpirvStats* stats() {
156 return stats_;
157 }
158
159 private:
160 // Returns the current instruction (the one last processed by the validator).
GetCurrentInstruction() const161 const Instruction& GetCurrentInstruction() const {
162 return vstate_->ordered_instructions().back();
163 }
164
165 SpirvStats* stats_;
166 spv_validator_options_t validator_options_;
167 std::unique_ptr<ValidationState_t> vstate_;
168 };
169
ProcessHeader(void * user_data,spv_endianness_t endian,uint32_t magic,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t schema)170 spv_result_t ProcessHeader(
171 void* user_data, spv_endianness_t endian, uint32_t magic,
172 uint32_t version, uint32_t generator, uint32_t id_bound,
173 uint32_t schema) {
174 StatsAggregator* stats_aggregator =
175 reinterpret_cast<StatsAggregator*>(user_data);
176 return stats_aggregator->ProcessHeader(
177 endian, magic, version, generator, id_bound, schema);
178 }
179
ProcessInstruction(void * user_data,const spv_parsed_instruction_t * inst)180 spv_result_t ProcessInstruction(
181 void* user_data, const spv_parsed_instruction_t* inst) {
182 StatsAggregator* stats_aggregator =
183 reinterpret_cast<StatsAggregator*>(user_data);
184 return stats_aggregator->ProcessInstruction(inst);
185 }
186
187 } // namespace
188
189 namespace libspirv {
190
AggregateStats(const spv_context_t & context,const uint32_t * words,const size_t num_words,spv_diagnostic * pDiagnostic,SpirvStats * stats)191 spv_result_t AggregateStats(
192 const spv_context_t& context, const uint32_t* words, const size_t num_words,
193 spv_diagnostic* pDiagnostic, SpirvStats* stats) {
194 spv_const_binary_t binary = {words, num_words};
195
196 spv_endianness_t endian;
197 spv_position_t position = {};
198 if (spvBinaryEndianness(&binary, &endian)) {
199 return libspirv::DiagnosticStream(position, context.consumer,
200 SPV_ERROR_INVALID_BINARY)
201 << "Invalid SPIR-V magic number.";
202 }
203
204 spv_header_t header;
205 if (spvBinaryHeaderGet(&binary, endian, &header)) {
206 return libspirv::DiagnosticStream(position, context.consumer,
207 SPV_ERROR_INVALID_BINARY)
208 << "Invalid SPIR-V header.";
209 }
210
211 StatsAggregator stats_aggregator(stats, &context);
212
213 return spvBinaryParse(&context, &stats_aggregator, words, num_words,
214 ProcessHeader, ProcessInstruction, pDiagnostic);
215 }
216
217 } // namespace libspirv
218