• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "spirv_stats.h"
16 
17 #include <cassert>
18 
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "binary.h"
25 #include "diagnostic.h"
26 #include "enum_string_mapping.h"
27 #include "extensions.h"
28 #include "instruction.h"
29 #include "opcode.h"
30 #include "operand.h"
31 #include "spirv-tools/libspirv.h"
32 #include "spirv_endian.h"
33 #include "spirv_validator_options.h"
34 #include "validate.h"
35 #include "val/instruction.h"
36 #include "val/validation_state.h"
37 
38 using libspirv::Instruction;
39 using libspirv::SpirvStats;
40 using libspirv::ValidationState_t;
41 
42 namespace {
43 
44 // Helper class for stats aggregation. Receives as in/out parameter.
45 // Constructs ValidationState and updates it by running validator for each
46 // instruction.
47 class StatsAggregator {
48  public:
StatsAggregator(SpirvStats * in_out_stats,const spv_const_context context)49   StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context) {
50     stats_ = in_out_stats;
51     vstate_.reset(new ValidationState_t(context, &validator_options_));
52   }
53 
54   // Collects header statistics and sets correct id_bound.
ProcessHeader(spv_endianness_t,uint32_t,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t)55   spv_result_t ProcessHeader(
56       spv_endianness_t /* endian */, uint32_t /* magic */,
57       uint32_t version, uint32_t generator, uint32_t id_bound,
58       uint32_t /* schema */) {
59     vstate_->setIdBound(id_bound);
60     ++stats_->version_hist[version];
61     ++stats_->generator_hist[generator];
62     return SPV_SUCCESS;
63   }
64 
65   // Runs validator to validate the instruction and update vstate_,
66   // then procession the instruction to collect stats.
ProcessInstruction(const spv_parsed_instruction_t * inst)67   spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) {
68     const spv_result_t validation_result =
69         spvtools::ValidateInstructionAndUpdateValidationState(vstate_.get(), inst);
70     if (validation_result != SPV_SUCCESS)
71       return validation_result;
72 
73     ProcessOpcode();
74     ProcessCapability();
75     ProcessExtension();
76     ProcessConstant();
77 
78     return SPV_SUCCESS;
79   }
80 
81   // Collects OpCapability statistics.
ProcessCapability()82   void ProcessCapability() {
83     const Instruction& inst = GetCurrentInstruction();
84     if (inst.opcode() != SpvOpCapability) return;
85     const uint32_t capability = inst.word(inst.operands()[0].offset);
86     ++stats_->capability_hist[capability];
87   }
88 
89   // Collects OpExtension statistics.
ProcessExtension()90   void ProcessExtension() {
91     const Instruction& inst = GetCurrentInstruction();
92     if (inst.opcode() != SpvOpExtension) return;
93     const std::string extension = libspirv::GetExtensionString(&inst.c_inst());
94     ++stats_->extension_hist[extension];
95   }
96 
97   // Collects OpCode statistics.
ProcessOpcode()98   void ProcessOpcode() {
99     auto inst_it = vstate_->ordered_instructions().rbegin();
100     const SpvOp opcode = inst_it->opcode();
101     ++stats_->opcode_hist[opcode];
102 
103     ++inst_it;
104     auto step_it = stats_->opcode_markov_hist.begin();
105     for (; inst_it != vstate_->ordered_instructions().rend() &&
106          step_it != stats_->opcode_markov_hist.end(); ++inst_it, ++step_it) {
107       auto& hist = (*step_it)[inst_it->opcode()];
108       ++hist[opcode];
109     }
110   }
111 
112   // Collects OpConstant statistics.
ProcessConstant()113   void ProcessConstant() {
114     const Instruction& inst = GetCurrentInstruction();
115     if (inst.opcode() != SpvOpConstant) return;
116     const uint32_t type_id = inst.GetOperandAs<uint32_t>(0);
117     const auto type_decl_it = vstate_->all_definitions().find(type_id);
118     assert(type_decl_it != vstate_->all_definitions().end());
119     const Instruction& type_decl_inst = *type_decl_it->second;
120     const SpvOp type_op = type_decl_inst.opcode();
121     if (type_op == SpvOpTypeInt) {
122       const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
123       const uint32_t is_signed = type_decl_inst.GetOperandAs<uint32_t>(2);
124       assert(is_signed == 0 || is_signed == 1);
125       if (bit_width == 16) {
126         if (is_signed)
127           ++stats_->s16_constant_hist[inst.GetOperandAs<int16_t>(2)];
128         else
129           ++stats_->u16_constant_hist[inst.GetOperandAs<uint16_t>(2)];
130       } else if (bit_width == 32) {
131         if (is_signed)
132           ++stats_->s32_constant_hist[inst.GetOperandAs<int32_t>(2)];
133         else
134           ++stats_->u32_constant_hist[inst.GetOperandAs<uint32_t>(2)];
135       } else if (bit_width == 64) {
136         if (is_signed)
137           ++stats_->s64_constant_hist[inst.GetOperandAs<int64_t>(2)];
138         else
139           ++stats_->u64_constant_hist[inst.GetOperandAs<uint64_t>(2)];
140       } else {
141         assert(false && "TypeInt bit width is not 16, 32 or 64");
142       }
143     } else if (type_op == SpvOpTypeFloat) {
144       const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
145       if (bit_width == 32) {
146         ++stats_->f32_constant_hist[inst.GetOperandAs<float>(2)];
147       } else if (bit_width == 64) {
148         ++stats_->f64_constant_hist[inst.GetOperandAs<double>(2)];
149       } else {
150         assert(bit_width == 16);
151       }
152     }
153   }
154 
stats()155   SpirvStats* stats() {
156     return stats_;
157   }
158 
159  private:
160   // Returns the current instruction (the one last processed by the validator).
GetCurrentInstruction() const161   const Instruction& GetCurrentInstruction() const {
162     return vstate_->ordered_instructions().back();
163   }
164 
165   SpirvStats* stats_;
166   spv_validator_options_t validator_options_;
167   std::unique_ptr<ValidationState_t> vstate_;
168 };
169 
ProcessHeader(void * user_data,spv_endianness_t endian,uint32_t magic,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t schema)170 spv_result_t ProcessHeader(
171     void* user_data, spv_endianness_t endian, uint32_t magic,
172     uint32_t version, uint32_t generator, uint32_t id_bound,
173     uint32_t schema) {
174   StatsAggregator* stats_aggregator =
175       reinterpret_cast<StatsAggregator*>(user_data);
176   return stats_aggregator->ProcessHeader(
177       endian, magic, version, generator, id_bound, schema);
178 }
179 
ProcessInstruction(void * user_data,const spv_parsed_instruction_t * inst)180 spv_result_t ProcessInstruction(
181     void* user_data, const spv_parsed_instruction_t* inst) {
182   StatsAggregator* stats_aggregator =
183       reinterpret_cast<StatsAggregator*>(user_data);
184   return stats_aggregator->ProcessInstruction(inst);
185 }
186 
187 }  // namespace
188 
189 namespace libspirv {
190 
AggregateStats(const spv_context_t & context,const uint32_t * words,const size_t num_words,spv_diagnostic * pDiagnostic,SpirvStats * stats)191 spv_result_t AggregateStats(
192     const spv_context_t& context, const uint32_t* words, const size_t num_words,
193     spv_diagnostic* pDiagnostic, SpirvStats* stats) {
194   spv_const_binary_t binary = {words, num_words};
195 
196   spv_endianness_t endian;
197   spv_position_t position = {};
198   if (spvBinaryEndianness(&binary, &endian)) {
199     return libspirv::DiagnosticStream(position, context.consumer,
200                                       SPV_ERROR_INVALID_BINARY)
201         << "Invalid SPIR-V magic number.";
202   }
203 
204   spv_header_t header;
205   if (spvBinaryHeaderGet(&binary, endian, &header)) {
206     return libspirv::DiagnosticStream(position, context.consumer,
207                                       SPV_ERROR_INVALID_BINARY)
208         << "Invalid SPIR-V header.";
209   }
210 
211   StatsAggregator stats_aggregator(stats, &context);
212 
213   return spvBinaryParse(&context, &stats_aggregator, words, num_words,
214                         ProcessHeader, ProcessInstruction, pDiagnostic);
215 }
216 
217 }  // namespace libspirv
218