1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Contains
16 // - SPIR-V to MARK-V encoder
17 // - MARK-V to SPIR-V decoder
18 //
19 // MARK-V is a compression format for SPIR-V binaries. It strips away
20 // non-essential information (such as result ids which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
22 //
23 // MarkvModel is a flatbuffers object containing a set of rules defining how
24 // compression/decompression is done (coding schemes, dictionaries).
25
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
36
37 #include "binary.h"
38 #include "diagnostic.h"
39 #include "enum_string_mapping.h"
40 #include "extensions.h"
41 #include "ext_inst.h"
42 #include "instruction.h"
43 #include "opcode.h"
44 #include "operand.h"
45 #include "spirv-tools/libspirv.h"
46 #include "spirv-tools/markv.h"
47 #include "spirv_endian.h"
48 #include "spirv_validator_options.h"
49 #include "util/bit_stream.h"
50 #include "util/parse_number.h"
51 #include "validate.h"
52 #include "val/instruction.h"
53 #include "val/validation_state.h"
54
55 using libspirv::Instruction;
56 using libspirv::ValidationState_t;
57 using spvtools::ValidateInstructionAndUpdateValidationState;
58 using spvutils::BitReaderWord64;
59 using spvutils::BitWriterWord64;
60
// Options object handed to the MARK-V encoder. It currently has no fields.
struct spv_markv_encoder_options_t {};
63
// Options object handed to the MARK-V decoder. It currently has no fields.
struct spv_markv_decoder_options_t {};
66
67 namespace {
68
69 const uint32_t kSpirvMagicNumber = SpvMagicNumber;
70 const uint32_t kMarkvMagicNumber = 0x07230303;
71
72 enum {
73 kMarkvFirstOpcode = 65536,
74 kMarkvOpNextInstructionEncodesResultId = 65536,
75 };
76
77 const size_t kCommentNumWhitespaces = 2;
78
// TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer
// containing MARK-V model for a specific dataset.
//
// Supplies the parameters (chunk lengths and block exponents) which the codec
// passes to the variable width bit stream routines. Every value is currently
// hard-coded.
class MarkvModel {
 public:
  // Parameters for instruction-level fields.
  size_t opcode_chunk_length() const { return 7; }
  size_t num_operands_chunk_length() const { return 3; }
  size_t id_index_chunk_length() const { return 3; }

  // Parameters for 16-bit literals.
  size_t u16_chunk_length() const { return 4; }
  size_t s16_chunk_length() const { return 4; }
  size_t s16_block_exponent() const { return 6; }

  // Parameters for 32-bit literals.
  size_t u32_chunk_length() const { return 8; }
  size_t s32_chunk_length() const { return 8; }
  size_t s32_block_exponent() const { return 10; }

  // Parameters for 64-bit literals.
  size_t u64_chunk_length() const { return 8; }
  size_t s64_chunk_length() const { return 8; }
  size_t s64_block_exponent() const { return 10; }
};
99
GetDefaultModel()100 const MarkvModel* GetDefaultModel() {
101 static MarkvModel model;
102 return &model;
103 }
104
105 // Returns chunk length used for variable length encoding of spirv operand
106 // words. Returns zero if operand type corresponds to potentially multiple
107 // words or a word which is not expected to profit from variable width encoding.
108 // Chunk length is selected based on the size of expected value.
109 // Most of these values will later be encoded with probability-based coding,
110 // but variable width integer coding is a good quick solution.
111 // TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
GetOperandVariableWidthChunkLength(spv_operand_type_t type)112 size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
113 switch (type) {
114 case SPV_OPERAND_TYPE_TYPE_ID:
115 return 4;
116 case SPV_OPERAND_TYPE_RESULT_ID:
117 case SPV_OPERAND_TYPE_ID:
118 case SPV_OPERAND_TYPE_SCOPE_ID:
119 case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
120 return 8;
121 case SPV_OPERAND_TYPE_LITERAL_INTEGER:
122 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
123 return 6;
124 case SPV_OPERAND_TYPE_CAPABILITY:
125 return 6;
126 case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
127 case SPV_OPERAND_TYPE_EXECUTION_MODEL:
128 return 3;
129 case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
130 case SPV_OPERAND_TYPE_MEMORY_MODEL:
131 return 2;
132 case SPV_OPERAND_TYPE_EXECUTION_MODE:
133 return 6;
134 case SPV_OPERAND_TYPE_STORAGE_CLASS:
135 return 4;
136 case SPV_OPERAND_TYPE_DIMENSIONALITY:
137 case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
138 return 3;
139 case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
140 return 2;
141 case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
142 return 6;
143 case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
144 case SPV_OPERAND_TYPE_LINKAGE_TYPE:
145 case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
146 case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
147 return 2;
148 case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
149 return 3;
150 case SPV_OPERAND_TYPE_DECORATION:
151 case SPV_OPERAND_TYPE_BUILT_IN:
152 return 6;
153 case SPV_OPERAND_TYPE_GROUP_OPERATION:
154 case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
155 case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
156 return 2;
157 case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
158 case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
159 case SPV_OPERAND_TYPE_LOOP_CONTROL:
160 case SPV_OPERAND_TYPE_IMAGE:
161 case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
162 case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
163 case SPV_OPERAND_TYPE_SELECTION_CONTROL:
164 return 4;
165 case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
166 return 6;
167 default:
168 return 0;
169 }
170 return 0;
171 }
172
173 // Returns true if the opcode has a fixed number of operands. May return a
174 // false negative.
OpcodeHasFixedNumberOfOperands(SpvOp opcode)175 bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
176 switch (opcode) {
177 // TODO(atgoo@github.com) This is not a complete list.
178 case SpvOpNop:
179 case SpvOpName:
180 case SpvOpUndef:
181 case SpvOpSizeOf:
182 case SpvOpLine:
183 case SpvOpNoLine:
184 case SpvOpDecorationGroup:
185 case SpvOpExtension:
186 case SpvOpExtInstImport:
187 case SpvOpMemoryModel:
188 case SpvOpCapability:
189 case SpvOpTypeVoid:
190 case SpvOpTypeBool:
191 case SpvOpTypeInt:
192 case SpvOpTypeFloat:
193 case SpvOpTypeVector:
194 case SpvOpTypeMatrix:
195 case SpvOpTypeSampler:
196 case SpvOpTypeSampledImage:
197 case SpvOpTypeArray:
198 case SpvOpTypePointer:
199 case SpvOpConstantTrue:
200 case SpvOpConstantFalse:
201 case SpvOpLabel:
202 case SpvOpBranch:
203 case SpvOpFunction:
204 case SpvOpFunctionParameter:
205 case SpvOpFunctionEnd:
206 case SpvOpBitcast:
207 case SpvOpCopyObject:
208 case SpvOpTranspose:
209 case SpvOpSNegate:
210 case SpvOpFNegate:
211 case SpvOpIAdd:
212 case SpvOpFAdd:
213 case SpvOpISub:
214 case SpvOpFSub:
215 case SpvOpIMul:
216 case SpvOpFMul:
217 case SpvOpUDiv:
218 case SpvOpSDiv:
219 case SpvOpFDiv:
220 case SpvOpUMod:
221 case SpvOpSRem:
222 case SpvOpSMod:
223 case SpvOpFRem:
224 case SpvOpFMod:
225 case SpvOpVectorTimesScalar:
226 case SpvOpMatrixTimesScalar:
227 case SpvOpVectorTimesMatrix:
228 case SpvOpMatrixTimesVector:
229 case SpvOpMatrixTimesMatrix:
230 case SpvOpOuterProduct:
231 case SpvOpDot:
232 return true;
233 default:
234 break;
235 }
236 return false;
237 }
238
// Returns how many bits separate |bit_pos| from the next byte boundary.
// The result is zero when |bit_pos| already sits on a byte boundary.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_byte = bit_pos % 8;
  return bits_into_byte == 0 ? 0 : 8 - bits_into_byte;
}
242
// Returns true if the stream should be padded to the next byte boundary, which
// is the case whenever |bit_pos| is not already byte aligned.
// A stricter rule (pad only when at most 2 bits are missing) is currently
// disabled.
bool ShouldByteBreak(size_t bit_pos) {
  const bool is_byte_aligned = (bit_pos % 8) == 0;
  return !is_byte_aligned;
}
247
// Defines and returns current MARK-V version.
// The major version occupies the bits above bit 15, the minor version is
// OR-ed into the low bits.
uint32_t GetMarkvVersion() {
  const uint32_t kMajor = 1;
  const uint32_t kMinor = 0;
  const uint32_t major_part = kMajor << 16;
  return major_part | kMinor;
}
254
// Accumulates comment text in an internal string stream. Consecutive bit
// sequences are joined with '-', any other kind of output resets the joining,
// which produces text like: 1100-0 1110-1-1100-1-1111-0 110-0.
class CommentLogger {
 public:
  // Appends |str| verbatim.
  void AppendText(const std::string& str) {
    ss_ << str;
    use_delimiter_ = false;
  }

  // Appends |str| followed by a line break.
  void AppendTextNewLine(const std::string& str) {
    ss_ << str << "\n";
    use_delimiter_ = false;
  }

  // Appends the bit sequence |str|. A "-" is put in front of it if the
  // previous output was also a bit sequence.
  void AppendBitSequence(const std::string& str) {
    ss_ << (use_delimiter_ ? "-" : "") << str;
    use_delimiter_ = true;
  }

  // Appends |num| space characters.
  void AppendWhitespaces(size_t num) {
    ss_ << std::string(num, ' ');
    use_delimiter_ = false;
  }

  // Appends a line break.
  void NewLine() {
    ss_ << "\n";
    use_delimiter_ = false;
  }

  // Returns everything logged so far.
  std::string GetText() const { return ss_.str(); }

 private:
  // Sink for all logged text.
  std::stringstream ss_;

  // If true a delimiter will be appended before the next bit sequence.
  bool use_delimiter_ = false;
};
301
302 // Creates spv_text object containing text from |str|.
303 // The returned value is owned by the caller and needs to be destroyed with
304 // spvTextDestroy.
CreateSpvText(const std::string & str)305 spv_text CreateSpvText(const std::string& str) {
306 spv_text out = new spv_text_t();
307 assert(out);
308 char* cstr = new char[str.length() + 1];
309 assert(cstr);
310 std::strncpy(cstr, str.c_str(), str.length());
311 cstr[str.length()] = '\0';
312 out->str = cstr;
313 out->length = str.length();
314 return out;
315 }
316
317 // Base class for MARK-V encoder and decoder. Contains common functionality
318 // such as:
319 // - Validator connection and validation state.
320 // - SPIR-V grammar and helper functions.
321 class MarkvCodecBase {
322 public:
~MarkvCodecBase()323 virtual ~MarkvCodecBase() {
324 spvValidatorOptionsDestroy(validator_options_);
325 }
326
327 MarkvCodecBase() = delete;
328
SetModel(const MarkvModel * model)329 void SetModel(const MarkvModel* model) {
330 model_ = model;
331 }
332
333 protected:
334 struct MarkvHeader {
MarkvHeader__anon5783fbf10111::MarkvCodecBase::MarkvHeader335 MarkvHeader() {
336 magic_number = kMarkvMagicNumber;
337 markv_version = GetMarkvVersion();
338 markv_model = 0;
339 markv_length_in_bits = 0;
340 spirv_version = 0;
341 spirv_generator = 0;
342 }
343
344 uint32_t magic_number;
345 uint32_t markv_version;
346 // Magic number to identify or verify MarkvModel used for encoding.
347 uint32_t markv_model;
348 uint32_t markv_length_in_bits;
349 uint32_t spirv_version;
350 uint32_t spirv_generator;
351 };
352
MarkvCodecBase(spv_const_context context,spv_validator_options validator_options)353 explicit MarkvCodecBase(spv_const_context context,
354 spv_validator_options validator_options)
355 : validator_options_(validator_options),
356 vstate_(context, validator_options_), grammar_(context),
357 model_(GetDefaultModel()) {}
358
359 // Validates a single instruction and updates validation state of the module.
UpdateValidationState(const spv_parsed_instruction_t & inst)360 spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
361 return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
362 }
363
364 // Returns the current instruction (the one last processed by the validator).
GetCurrentInstruction() const365 const Instruction& GetCurrentInstruction() const {
366 return vstate_.ordered_instructions().back();
367 }
368
369 spv_validator_options validator_options_;
370 ValidationState_t vstate_;
371 const libspirv::AssemblyGrammar grammar_;
372 MarkvHeader header_;
373 const MarkvModel* model_;
374
375 // Move-to-front list of all ids.
376 // TODO(atgoo@github.com) Consider a better move-to-front implementation.
377 std::list<uint32_t> move_to_front_ids_;
378 };
379
380 // SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
381 // EncodeInstruction which can be used as callback by spvBinaryParse.
382 // Encoded binary is written to an internally maintained bitstream.
383 // After the last instruction is encoded, the resulting MARK-V binary can be
384 // acquired by calling GetMarkvBinary().
385 // The encoder uses SPIR-V validator to keep internal state, therefore
386 // SPIR-V binary needs to be able to pass validator checks.
387 // CreateCommentsLogger() can be used to enable the encoder to write comments
388 // on how encoding was done, which can later be accessed with GetComments().
389 class MarkvEncoder : public MarkvCodecBase {
390 public:
MarkvEncoder(spv_const_context context,spv_const_markv_encoder_options options)391 MarkvEncoder(spv_const_context context,
392 spv_const_markv_encoder_options options)
393 : MarkvCodecBase(context, GetValidatorOptions(options)),
394 options_(options) {
395 (void) options_;
396 }
397
398 // Writes data from SPIR-V header to MARK-V header.
EncodeHeader(spv_endianness_t,uint32_t,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t)399 spv_result_t EncodeHeader(
400 spv_endianness_t /* endian */, uint32_t /* magic */,
401 uint32_t version, uint32_t generator, uint32_t id_bound,
402 uint32_t /* schema */) {
403 vstate_.setIdBound(id_bound);
404 header_.spirv_version = version;
405 header_.spirv_generator = generator;
406 return SPV_SUCCESS;
407 }
408
409 // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
410 // Operation can fail if the instruction fails to pass the validator or if
411 // the encoder stubmles on something unexpected.
412 spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
413
414 // Concatenates MARK-V header and the bit stream with encoded instructions
415 // into a single buffer and returns it as spv_markv_binary. The returned
416 // value is owned by the caller and needs to be destroyed with
417 // spvMarkvBinaryDestroy().
GetMarkvBinary()418 spv_markv_binary GetMarkvBinary() {
419 header_.markv_length_in_bits =
420 static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
421 const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
422
423 spv_markv_binary markv_binary = new spv_markv_binary_t();
424 markv_binary->data = new uint8_t[num_bytes];
425 markv_binary->length = num_bytes;
426 assert(writer_.GetData());
427 std::memcpy(markv_binary->data, &header_, sizeof(header_));
428 std::memcpy(markv_binary->data + sizeof(header_),
429 writer_.GetData(), writer_.GetDataSizeBytes());
430 return markv_binary;
431 }
432
433 // Creates an internal logger which writes comments on the encoding process.
434 // Output can later be accessed with GetComments().
CreateCommentsLogger()435 void CreateCommentsLogger() {
436 logger_.reset(new CommentLogger());
437 writer_.SetCallback([this](const std::string& str){
438 logger_->AppendBitSequence(str);
439 });
440 }
441
442 // Optionally adds disassembly to the comments.
443 // Disassembly should contain all instructions in the module separated by
444 // \n, and no header.
SetDisassembly(std::string && disassembly)445 void SetDisassembly(std::string&& disassembly) {
446 disassembly_.reset(new std::stringstream(std::move(disassembly)));
447 }
448
449 // Extracts the next instruction line from the disassembly and logs it.
LogDisassemblyInstruction()450 void LogDisassemblyInstruction() {
451 if (logger_ && disassembly_) {
452 std::string line;
453 std::getline(*disassembly_, line, '\n');
454 logger_->AppendTextNewLine(line);
455 }
456 }
457
458 // Extracts the text from the comment logger.
GetComments() const459 std::string GetComments() const {
460 if (!logger_)
461 return "";
462 return logger_->GetText();
463 }
464
465 private:
466 // Creates and returns validator options. Return value owned by the caller.
GetValidatorOptions(spv_const_markv_encoder_options)467 static spv_validator_options GetValidatorOptions(
468 spv_const_markv_encoder_options) {
469 return spvValidatorOptionsCreate();
470 }
471
472 // Writes a single word to bit stream. |type| determines if the word is
473 // encoded and how.
EncodeOperandWord(spv_operand_type_t type,uint32_t word)474 void EncodeOperandWord(spv_operand_type_t type, uint32_t word) {
475 const size_t chunk_length =
476 GetOperandVariableWidthChunkLength(type);
477 if (chunk_length) {
478 writer_.WriteVariableWidthU32(word, chunk_length);
479 } else {
480 writer_.WriteUnencoded(word);
481 }
482 }
483
484 // Returns id index and updates move-to-front.
485 // Index is uint16 as SPIR-V module is guaranteed to have no more than 65535
486 // instructions.
GetIdIndex(uint32_t id)487 uint16_t GetIdIndex(uint32_t id) {
488 if (all_known_ids_.count(id)) {
489 uint16_t index = 0;
490 for (auto it = move_to_front_ids_.begin();
491 it != move_to_front_ids_.end(); ++it) {
492 if (*it == id) {
493 if (index != 0) {
494 move_to_front_ids_.erase(it);
495 move_to_front_ids_.push_front(id);
496 }
497 return index;
498 }
499 ++index;
500 }
501 assert(0 && "Id not found in move_to_front_ids_");
502 return 0;
503 } else {
504 all_known_ids_.insert(id);
505 move_to_front_ids_.push_front(id);
506 return static_cast<uint16_t>(move_to_front_ids_.size() - 1);
507 }
508 }
509
AddByteBreakIfAgreed()510 void AddByteBreakIfAgreed() {
511 if (!ShouldByteBreak(writer_.GetNumBits()))
512 return;
513
514 if (logger_) {
515 logger_->AppendWhitespaces(kCommentNumWhitespaces);
516 logger_->AppendText("ByteBreak:");
517 }
518
519 writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits()));
520 }
521
522 // Encodes a literal number operand and writes it to the bit stream.
523 void EncodeLiteralNumber(const Instruction& instruction,
524 const spv_parsed_operand_t& operand);
525
526 spv_const_markv_encoder_options options_;
527
528 // Bit stream where encoded instructions are written.
529 BitWriterWord64 writer_;
530
531 // If not nullptr, encoder will write comments.
532 std::unique_ptr<CommentLogger> logger_;
533
534 // If not nullptr, disassembled instruction lines will be written to comments.
535 // Format: \n separated instruction lines, no header.
536 std::unique_ptr<std::stringstream> disassembly_;
537
538 // All ids which were previosly encountered in the module.
539 std::unordered_set<uint32_t> all_known_ids_;
540 };
541
542 // Decodes MARK-V buffers written by MarkvEncoder.
543 class MarkvDecoder : public MarkvCodecBase {
544 public:
MarkvDecoder(spv_const_context context,const uint8_t * markv_data,size_t markv_size_bytes,spv_const_markv_decoder_options options)545 MarkvDecoder(spv_const_context context,
546 const uint8_t* markv_data,
547 size_t markv_size_bytes,
548 spv_const_markv_decoder_options options)
549 : MarkvCodecBase(context, GetValidatorOptions(options)),
550 options_(options), reader_(markv_data, markv_size_bytes) {
551 (void) options_;
552 vstate_.setIdBound(1);
553 parsed_operands_.reserve(25);
554 }
555
556 // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
557 // Can be called only once. Fails if data of wrong format or ends prematurely,
558 // of if validation fails.
559 spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
560
561 private:
562 // Describes the format of a typed literal number.
563 struct NumberType {
564 spv_number_kind_t type;
565 uint32_t bit_width;
566 };
567
568 // Creates and returns validator options. Return value owned by the caller.
GetValidatorOptions(spv_const_markv_decoder_options)569 static spv_validator_options GetValidatorOptions(
570 spv_const_markv_decoder_options) {
571 return spvValidatorOptionsCreate();
572 }
573
574 // Reads a single word from bit stream. |type| determines if the word needs
575 // to be decoded and how. Returns false if read fails.
DecodeOperandWord(spv_operand_type_t type,uint32_t * word)576 bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) {
577 const size_t chunk_length = GetOperandVariableWidthChunkLength(type);
578 if (chunk_length) {
579 return reader_.ReadVariableWidthU32(word, chunk_length);
580 } else {
581 return reader_.ReadUnencoded(word);
582 }
583 }
584
585 // Fetches the id from the move-to-front list and moves it to front.
GetIdAndMoveToFront(uint16_t index)586 uint32_t GetIdAndMoveToFront(uint16_t index) {
587 if (index >= move_to_front_ids_.size()) {
588 // Issue new id.
589 const uint32_t id = vstate_.getIdBound();
590 move_to_front_ids_.push_front(id);
591 vstate_.setIdBound(id + 1);
592 return id;
593 } else {
594 if (index == 0)
595 return move_to_front_ids_.front();
596
597 // Iterate to index.
598 auto it = move_to_front_ids_.begin();
599 for (size_t i = 0; i < index; ++i)
600 ++it;
601 const uint32_t id = *it;
602 move_to_front_ids_.erase(it);
603 move_to_front_ids_.push_front(id);
604 return id;
605 }
606 }
607
608 // Decodes id index and fetches the id from move-to-front list.
DecodeId(uint32_t * id)609 bool DecodeId(uint32_t* id) {
610 uint16_t index = 0;
611 if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length()))
612 return false;
613
614 *id = GetIdAndMoveToFront(index);
615 return true;
616 }
617
ReadToByteBreakIfAgreed()618 bool ReadToByteBreakIfAgreed() {
619 if (!ShouldByteBreak(reader_.GetNumReadBits()))
620 return true;
621
622 uint64_t bits = 0;
623 if (!reader_.ReadBits(&bits,
624 GetNumBitsToNextByte(reader_.GetNumReadBits())))
625 return false;
626
627 if (bits != 0)
628 return false;
629
630 return true;
631 }
632
633 // Reads a literal number as it is described in |operand| from the bit stream,
634 // decodes and writes it to spirv_.
635 spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
636
637 // Reads instruction from bit stream, decodes and validates it.
638 // Decoded instruction is valid until the next call of DecodeInstruction().
639 spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst);
640
641 // Read operand from the stream decodes and validates it.
642 spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset,
643 spv_parsed_instruction_t* inst,
644 const spv_operand_type_t type,
645 spv_operand_pattern_t* expected_operands,
646 bool read_result_id);
647
648 // Records the numeric type for an operand according to the type information
649 // associated with the given non-zero type Id. This can fail if the type Id
650 // is not a type Id, or if the type Id does not reference a scalar numeric
651 // type. On success, return SPV_SUCCESS and populates the num_words,
652 // number_kind, and number_bit_width fields of parsed_operand.
653 spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
654 uint32_t type_id);
655
656 // Records the number type for the given instruction, if that
657 // instruction generates a type. For types that aren't scalar numbers,
658 // record something with number kind SPV_NUMBER_NONE.
659 void RecordNumberType(const spv_parsed_instruction_t& inst);
660
661 spv_const_markv_decoder_options options_;
662
663 // Temporary sink where decoded SPIR-V words are written. Once it contains the
664 // entire module, the container is moved and returned.
665 std::vector<uint32_t> spirv_;
666
667 // Bit stream containing encoded data.
668 BitReaderWord64 reader_;
669
670 // Temporary storage for operands of the currently parsed instruction.
671 // Valid until next DecodeInstruction call.
672 std::vector<spv_parsed_operand_t> parsed_operands_;
673
674 // Maps a result ID to its type ID. By convention:
675 // - a result ID that is a type definition maps to itself.
676 // - a result ID without a type maps to 0. (E.g. for OpLabel)
677 std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
678 // Maps a type ID to its number type description.
679 std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
680 // Maps an ExtInstImport id to the extended instruction type.
681 std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
682 };
683
EncodeLiteralNumber(const Instruction & instruction,const spv_parsed_operand_t & operand)684 void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction,
685 const spv_parsed_operand_t& operand) {
686 if (operand.number_bit_width == 32) {
687 const uint32_t word = instruction.word(operand.offset);
688 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
689 writer_.WriteVariableWidthU32(word, model_->u32_chunk_length());
690 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
691 int32_t val = 0;
692 std::memcpy(&val, &word, 4);
693 writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(),
694 model_->s32_block_exponent());
695 } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
696 writer_.WriteUnencoded(word);
697 } else {
698 assert(0);
699 }
700 } else if (operand.number_bit_width == 16) {
701 const uint16_t word =
702 static_cast<uint16_t>(instruction.word(operand.offset));
703 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
704 writer_.WriteVariableWidthU16(word, model_->u16_chunk_length());
705 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
706 int16_t val = 0;
707 std::memcpy(&val, &word, 2);
708 writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(),
709 model_->s16_block_exponent());
710 } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
711 // TODO(atgoo@github.com) Write only 16 bits.
712 writer_.WriteUnencoded(word);
713 } else {
714 assert(0);
715 }
716 } else {
717 assert(operand.number_bit_width == 64);
718 const uint64_t word =
719 uint64_t(instruction.word(operand.offset)) |
720 (uint64_t(instruction.word(operand.offset + 1)) << 32);
721 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
722 writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
723 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
724 int64_t val = 0;
725 std::memcpy(&val, &word, 8);
726 writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
727 model_->s64_block_exponent());
728 } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
729 writer_.WriteUnencoded(word);
730 } else {
731 assert(0);
732 }
733 }
734 }
735
EncodeInstruction(const spv_parsed_instruction_t & inst)736 spv_result_t MarkvEncoder::EncodeInstruction(
737 const spv_parsed_instruction_t& inst) {
738 const spv_result_t validation_result = UpdateValidationState(inst);
739 if (validation_result != SPV_SUCCESS)
740 return validation_result;
741
742 bool result_id_was_forward_declared = false;
743 if (all_known_ids_.count(inst.result_id)) {
744 // Result id of the instruction was forward declared.
745 // Write a service opcode to signal this to the decoder.
746 writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId,
747 model_->opcode_chunk_length());
748 result_id_was_forward_declared = true;
749 }
750
751 const Instruction& instruction = GetCurrentInstruction();
752 const auto& operands = instruction.operands();
753
754 LogDisassemblyInstruction();
755
756 // Write opcode.
757 writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length());
758
759 if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) {
760 // If the opcode has a variable number of operands, encode the number of
761 // operands with the instruction.
762
763 if (logger_)
764 logger_->AppendWhitespaces(kCommentNumWhitespaces);
765
766 writer_.WriteVariableWidthU16(inst.num_operands,
767 model_->num_operands_chunk_length());
768 }
769
770 // Write operands.
771 for (const auto& operand : operands) {
772 if (operand.type == SPV_OPERAND_TYPE_RESULT_ID &&
773 !result_id_was_forward_declared) {
774 // Register the id, but don't encode it.
775 GetIdIndex(instruction.word(operand.offset));
776 continue;
777 }
778
779 if (logger_)
780 logger_->AppendWhitespaces(kCommentNumWhitespaces);
781
782 if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) {
783 EncodeLiteralNumber(instruction, operand);
784 } else if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
785 const char* src =
786 reinterpret_cast<const char*>(&instruction.words()[operand.offset]);
787 const size_t length = spv_strnlen_s(src, operand.num_words * 4);
788 if (length == operand.num_words * 4)
789 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
790 << "Failed to find terminal character of literal string";
791 for (size_t i = 0; i < length + 1; ++i)
792 writer_.WriteUnencoded(src[i]);
793 } else if (spvIsIdType(operand.type)) {
794 const uint16_t id_index = GetIdIndex(instruction.word(operand.offset));
795 writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length());
796 } else {
797 for (int i = 0; i < operand.num_words; ++i) {
798 const uint32_t word = instruction.word(operand.offset + i);
799 EncodeOperandWord(operand.type, word);
800 }
801 }
802 }
803
804 AddByteBreakIfAgreed();
805
806 if (logger_) {
807 logger_->NewLine();
808 logger_->NewLine();
809 }
810
811 return SPV_SUCCESS;
812 }
813
DecodeLiteralNumber(const spv_parsed_operand_t & operand)814 spv_result_t MarkvDecoder::DecodeLiteralNumber(
815 const spv_parsed_operand_t& operand) {
816 if (operand.number_bit_width == 32) {
817 uint32_t word = 0;
818 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
819 if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length()))
820 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
821 << "Failed to read literal U32";
822 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
823 int32_t val = 0;
824 if (!reader_.ReadVariableWidthS32(&val, model_->s32_chunk_length(),
825 model_->s32_block_exponent()))
826 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
827 << "Failed to read literal S32";
828 std::memcpy(&word, &val, 4);
829 } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
830 if (!reader_.ReadUnencoded(&word))
831 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
832 << "Failed to read literal F32";
833 } else {
834 assert(0);
835 }
836 spirv_.push_back(word);
837 } else if (operand.number_bit_width == 16) {
838 uint32_t word = 0;
839 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
840 uint16_t val = 0;
841 if (!reader_.ReadVariableWidthU16(&val, model_->u16_chunk_length()))
842 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
843 << "Failed to read literal U16";
844 word = val;
845 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
846 int16_t val = 0;
847 if (!reader_.ReadVariableWidthS16(&val, model_->s16_chunk_length(),
848 model_->s16_block_exponent()))
849 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
850 << "Failed to read literal S16";
851 // Int16 is stored as int32 in SPIR-V, not as bits.
852 int32_t val32 = val;
853 std::memcpy(&word, &val32, 4);
854 } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
855 uint16_t word16 = 0;
856 if (!reader_.ReadUnencoded(&word16))
857 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
858 << "Failed to read literal F16";
859 word = word16;
860 } else {
861 assert(0);
862 }
863 spirv_.push_back(word);
864 } else {
865 assert(operand.number_bit_width == 64);
866 uint64_t word = 0;
867 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
868 if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
869 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
870 << "Failed to read literal U64";
871 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
872 int64_t val = 0;
873 if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
874 model_->s64_block_exponent()))
875 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
876 << "Failed to read literal S64";
877 std::memcpy(&word, &val, 8);
878 } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
879 if (!reader_.ReadUnencoded(&word))
880 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
881 << "Failed to read literal F64";
882 } else {
883 assert(0);
884 }
885 spirv_.push_back(static_cast<uint32_t>(word));
886 spirv_.push_back(static_cast<uint32_t>(word >> 32));
887 }
888 return SPV_SUCCESS;
889 }
890
DecodeModule(std::vector<uint32_t> * spirv_binary)891 spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
892 const bool header_read_success =
893 reader_.ReadUnencoded(&header_.magic_number) &&
894 reader_.ReadUnencoded(&header_.markv_version) &&
895 reader_.ReadUnencoded(&header_.markv_model) &&
896 reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
897 reader_.ReadUnencoded(&header_.spirv_version) &&
898 reader_.ReadUnencoded(&header_.spirv_generator);
899
900 if (!header_read_success)
901 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
902 << "Unable to read MARK-V header";
903
904 assert(header_.magic_number == kMarkvMagicNumber);
905 assert(header_.markv_length_in_bits > 0);
906
907 if (header_.magic_number != kMarkvMagicNumber)
908 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
909 << "MARK-V binary has incorrect magic number";
910
911 // TODO(atgoo@github.com): Print version strings.
912 if (header_.markv_version != GetMarkvVersion())
913 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
914 << "MARK-V binary and the codec have different versions";
915
916 spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
917 spirv_.resize(5, 0);
918 spirv_[0] = kSpirvMagicNumber;
919 spirv_[1] = header_.spirv_version;
920 spirv_[2] = header_.spirv_generator;
921
922 while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
923 spv_parsed_instruction_t inst = {};
924 const spv_result_t decode_result = DecodeInstruction(&inst);
925 if (decode_result != SPV_SUCCESS)
926 return decode_result;
927
928 const spv_result_t validation_result = UpdateValidationState(inst);
929 if (validation_result != SPV_SUCCESS)
930 return validation_result;
931 }
932
933
934 if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
935 !reader_.OnlyZeroesLeft()) {
936 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
937 << "MARK-V binary has wrong stated bit length "
938 << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
939 }
940
941 // Decoding of the module is finished, validation state should have correct
942 // id bound.
943 spirv_[3] = vstate_.getIdBound();
944
945 *spirv_binary = std::move(spirv_);
946 return SPV_SUCCESS;
947 }
948
949 // TODO(atgoo@github.com): The implementation borrows heavily from
950 // Parser::parseOperand.
951 // Consider coupling them together in some way once MARK-V codec is more mature.
952 // For now it's better to keep the code independent for experimentation
953 // purposes.
DecodeOperand(size_t instruction_offset,size_t operand_offset,spv_parsed_instruction_t * inst,const spv_operand_type_t type,spv_operand_pattern_t * expected_operands,bool read_result_id)954 spv_result_t MarkvDecoder::DecodeOperand(
955 size_t instruction_offset, size_t operand_offset,
956 spv_parsed_instruction_t* inst, const spv_operand_type_t type,
957 spv_operand_pattern_t* expected_operands,
958 bool read_result_id) {
959 const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
960
961 spv_parsed_operand_t parsed_operand;
962 memset(&parsed_operand, 0, sizeof(parsed_operand));
963
964 assert((operand_offset >> 16) == 0);
965 parsed_operand.offset = static_cast<uint16_t>(operand_offset);
966 parsed_operand.type = type;
967
968 // Set default values, may be updated later.
969 parsed_operand.number_kind = SPV_NUMBER_NONE;
970 parsed_operand.number_bit_width = 0;
971
972 const size_t first_word_index = spirv_.size();
973
974 switch (type) {
975 case SPV_OPERAND_TYPE_TYPE_ID: {
976 if (!DecodeId(&inst->type_id)) {
977 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
978 << "Failed to read type_id";
979 }
980
981 if (inst->type_id == 0)
982 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0";
983
984 spirv_.push_back(inst->type_id);
985 vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1));
986 break;
987 }
988
989 case SPV_OPERAND_TYPE_RESULT_ID: {
990 if (read_result_id) {
991 if (!DecodeId(&inst->result_id))
992 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
993 << "Failed to read result_id";
994 } else {
995 inst->result_id = vstate_.getIdBound();
996 vstate_.setIdBound(inst->result_id + 1);
997 move_to_front_ids_.push_front(inst->result_id);
998 }
999
1000 spirv_.push_back(inst->result_id);
1001
1002 // Save the result ID to type ID mapping.
1003 // In the grammar, type ID always appears before result ID.
1004 // A regular value maps to its type. Some instructions (e.g. OpLabel)
1005 // have no type Id, and will map to 0. The result Id for a
1006 // type-generating instruction (e.g. OpTypeInt) maps to itself.
1007 auto insertion_result = id_to_type_id_.emplace(
1008 inst->result_id,
1009 spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id);
1010 if(!insertion_result.second) {
1011 return vstate_.diag(SPV_ERROR_INVALID_ID)
1012 << "Unexpected behavior: id->type_id pair was already registered";
1013 }
1014 break;
1015 }
1016
1017 case SPV_OPERAND_TYPE_ID:
1018 case SPV_OPERAND_TYPE_OPTIONAL_ID:
1019 case SPV_OPERAND_TYPE_SCOPE_ID:
1020 case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
1021 uint32_t id = 0;
1022 if (!DecodeId(&id))
1023 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id";
1024
1025 if (id == 0)
1026 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
1027
1028 spirv_.push_back(id);
1029 vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1));
1030
1031 if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
1032
1033 parsed_operand.type = SPV_OPERAND_TYPE_ID;
1034
1035 if (opcode == SpvOpExtInst && parsed_operand.offset == 3) {
1036 // The current word is the extended instruction set id.
1037 // Set the extended instruction set type for the current instruction.
1038 auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
1039 if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
1040 return vstate_.diag(SPV_ERROR_INVALID_ID)
1041 << "OpExtInst set id " << id
1042 << " does not reference an OpExtInstImport result Id";
1043 }
1044 inst->ext_inst_type = ext_inst_type_iter->second;
1045 }
1046 }
1047 break;
1048 }
1049
1050 case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
1051 uint32_t word = 0;
1052 if (!DecodeOperandWord(type, &word))
1053 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1054 << "Failed to read enum";
1055
1056 spirv_.push_back(word);
1057
1058 assert(SpvOpExtInst == opcode);
1059 assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE);
1060 spv_ext_inst_desc ext_inst;
1061 if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst))
1062 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1063 << "Invalid extended instruction number: " << word;
1064 spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
1065 break;
1066 }
1067
1068 case SPV_OPERAND_TYPE_LITERAL_INTEGER:
1069 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
1070 // These are regular single-word literal integer operands.
1071 // Post-parsing validation should check the range of the parsed value.
1072 parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
1073 // It turns out they are always unsigned integers!
1074 parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT;
1075 parsed_operand.number_bit_width = 32;
1076
1077 uint32_t word = 0;
1078 if (!DecodeOperandWord(type, &word))
1079 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1080 << "Failed to read literal integer";
1081
1082 spirv_.push_back(word);
1083 break;
1084 }
1085
1086 case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
1087 case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
1088 parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
1089 if (opcode == SpvOpSwitch) {
1090 // The literal operands have the same type as the value
1091 // referenced by the selector Id.
1092 const uint32_t selector_id = spirv_.at(instruction_offset + 1);
1093 const auto type_id_iter = id_to_type_id_.find(selector_id);
1094 if (type_id_iter == id_to_type_id_.end() ||
1095 type_id_iter->second == 0) {
1096 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1097 << "Invalid OpSwitch: selector id " << selector_id
1098 << " has no type";
1099 }
1100 uint32_t type_id = type_id_iter->second;
1101
1102 if (selector_id == type_id) {
1103 // Recall that by convention, a result ID that is a type definition
1104 // maps to itself.
1105 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1106 << "Invalid OpSwitch: selector id " << selector_id
1107 << " is a type, not a value";
1108 }
1109 if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id))
1110 return error;
1111 if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT &&
1112 parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) {
1113 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1114 << "Invalid OpSwitch: selector id " << selector_id
1115 << " is not a scalar integer";
1116 }
1117 } else {
1118 assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
1119 // The literal number type is determined by the type Id for the
1120 // constant.
1121 assert(inst->type_id);
1122 if (auto error =
1123 SetNumericTypeInfoForType(&parsed_operand, inst->type_id))
1124 return error;
1125 }
1126
1127 if (auto error = DecodeLiteralNumber(parsed_operand))
1128 return error;
1129
1130 break;
1131
1132 case SPV_OPERAND_TYPE_LITERAL_STRING:
1133 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
1134 parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING;
1135 std::vector<char> str;
1136 // The loop is expected to terminate once we encounter '\0' or exhaust
1137 // the bit stream.
1138 while (true) {
1139 char ch = 0;
1140 if (!reader_.ReadUnencoded(&ch))
1141 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1142 << "Failed to read literal string";
1143
1144 str.push_back(ch);
1145
1146 if (ch == '\0')
1147 break;
1148 }
1149
1150 while (str.size() % 4 != 0)
1151 str.push_back('\0');
1152
1153 spirv_.resize(spirv_.size() + str.size() / 4);
1154 std::memcpy(&spirv_[first_word_index], str.data(), str.size());
1155
1156 if (SpvOpExtInstImport == opcode) {
1157 // Record the extended instruction type for the ID for this import.
1158 // There is only one string literal argument to OpExtInstImport,
1159 // so it's sufficient to guard this just on the opcode.
1160 const spv_ext_inst_type_t ext_inst_type =
1161 spvExtInstImportTypeGet(str.data());
1162 if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
1163 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1164 << "Invalid extended instruction import '" << str.data() << "'";
1165 }
1166 // We must have parsed a valid result ID. It's a condition
1167 // of the grammar, and we only accept non-zero result Ids.
1168 assert(inst->result_id);
1169 const bool inserted = import_id_to_ext_inst_type_.emplace(
1170 inst->result_id, ext_inst_type).second;
1171 (void)inserted;
1172 assert(inserted);
1173 }
1174 break;
1175 }
1176
1177 case SPV_OPERAND_TYPE_CAPABILITY:
1178 case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
1179 case SPV_OPERAND_TYPE_EXECUTION_MODEL:
1180 case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
1181 case SPV_OPERAND_TYPE_MEMORY_MODEL:
1182 case SPV_OPERAND_TYPE_EXECUTION_MODE:
1183 case SPV_OPERAND_TYPE_STORAGE_CLASS:
1184 case SPV_OPERAND_TYPE_DIMENSIONALITY:
1185 case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
1186 case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
1187 case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
1188 case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
1189 case SPV_OPERAND_TYPE_LINKAGE_TYPE:
1190 case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
1191 case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
1192 case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
1193 case SPV_OPERAND_TYPE_DECORATION:
1194 case SPV_OPERAND_TYPE_BUILT_IN:
1195 case SPV_OPERAND_TYPE_GROUP_OPERATION:
1196 case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
1197 case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
1198 // A single word that is a plain enum value.
1199 uint32_t word = 0;
1200 if (!DecodeOperandWord(type, &word))
1201 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1202 << "Failed to read enum";
1203
1204 spirv_.push_back(word);
1205
1206 // Map an optional operand type to its corresponding concrete type.
1207 if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
1208 parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
1209
1210 spv_operand_desc entry;
1211 if (grammar_.lookupOperand(type, word, &entry)) {
1212 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1213 << "Invalid "
1214 << spvOperandTypeStr(parsed_operand.type)
1215 << " operand: " << word;
1216 }
1217
1218 // Prepare to accept operands to this operand, if needed.
1219 spvPushOperandTypes(entry->operandTypes, expected_operands);
1220 break;
1221 }
1222
1223 case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
1224 case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
1225 case SPV_OPERAND_TYPE_LOOP_CONTROL:
1226 case SPV_OPERAND_TYPE_IMAGE:
1227 case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
1228 case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
1229 case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
1230 // This operand is a mask.
1231 uint32_t word = 0;
1232 if (!DecodeOperandWord(type, &word))
1233 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1234 << "Failed to read " << spvOperandTypeStr(type)
1235 << " for " << spvOpcodeString(SpvOp(inst->opcode));
1236
1237 spirv_.push_back(word);
1238
1239 // Map an optional operand type to its corresponding concrete type.
1240 if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
1241 parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
1242 else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
1243 parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
1244
1245 // Check validity of set mask bits. Also prepare for operands for those
1246 // masks if they have any. To get operand order correct, scan from
1247 // MSB to LSB since we can only prepend operands to a pattern.
1248 // The only case in the grammar where you have more than one mask bit
1249 // having an operand is for image operands. See SPIR-V 3.14 Image
1250 // Operands.
1251 uint32_t remaining_word = word;
1252 for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
1253 if (remaining_word & mask) {
1254 spv_operand_desc entry;
1255 if (grammar_.lookupOperand(type, mask, &entry)) {
1256 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1257 << "Invalid " << spvOperandTypeStr(parsed_operand.type)
1258 << " operand: " << word << " has invalid mask component "
1259 << mask;
1260 }
1261 remaining_word ^= mask;
1262 spvPushOperandTypes(entry->operandTypes, expected_operands);
1263 }
1264 }
1265 if (word == 0) {
1266 // An all-zeroes mask *might* also be valid.
1267 spv_operand_desc entry;
1268 if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
1269 // Prepare for its operands, if any.
1270 spvPushOperandTypes(entry->operandTypes, expected_operands);
1271 }
1272 }
1273 break;
1274 }
1275 default:
1276 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1277 << "Internal error: Unhandled operand type: " << type;
1278 }
1279
1280 parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index);
1281
1282 assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type));
1283 assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type));
1284
1285 parsed_operands_.push_back(parsed_operand);
1286
1287 return SPV_SUCCESS;
1288 }
1289
DecodeInstruction(spv_parsed_instruction_t * inst)1290 spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) {
1291 parsed_operands_.clear();
1292 const size_t instruction_offset = spirv_.size();
1293
1294 bool read_result_id = false;
1295
1296 while (true) {
1297 uint32_t word = 0;
1298 if (!reader_.ReadVariableWidthU32(&word,
1299 model_->opcode_chunk_length())) {
1300 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1301 << "Failed to read opcode of instruction";
1302 }
1303
1304 if (word >= kMarkvFirstOpcode) {
1305 if (word == kMarkvOpNextInstructionEncodesResultId) {
1306 read_result_id = true;
1307 } else {
1308 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1309 << "Encountered unknown MARK-V opcode";
1310 }
1311 } else {
1312 inst->opcode = static_cast<uint16_t>(word);
1313 break;
1314 }
1315 }
1316
1317 const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
1318
1319 // Opcode/num_words placeholder, the word will be filled in later.
1320 spirv_.push_back(0);
1321
1322 spv_opcode_desc opcode_desc;
1323 if (grammar_.lookupOpcode(opcode, &opcode_desc)
1324 != SPV_SUCCESS) {
1325 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
1326 }
1327
1328 spv_operand_pattern_t expected_operands;
1329 expected_operands.reserve(opcode_desc->numTypes);
1330 for (auto i = 0; i < opcode_desc->numTypes; i++)
1331 expected_operands.push_back(opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
1332
1333 if (!OpcodeHasFixedNumberOfOperands(opcode)) {
1334 if (!reader_.ReadVariableWidthU16(&inst->num_operands,
1335 model_->num_operands_chunk_length()))
1336 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1337 << "Failed to read num_operands of instruction";
1338 } else {
1339 inst->num_operands = static_cast<uint16_t>(expected_operands.size());
1340 }
1341
1342 for (size_t operand_index = 0;
1343 operand_index < static_cast<size_t>(inst->num_operands);
1344 ++operand_index) {
1345 assert(!expected_operands.empty());
1346 const spv_operand_type_t type =
1347 spvTakeFirstMatchableOperand(&expected_operands);
1348
1349 const size_t operand_offset = spirv_.size() - instruction_offset;
1350
1351 const spv_result_t decode_result =
1352 DecodeOperand(instruction_offset, operand_offset, inst, type,
1353 &expected_operands, read_result_id);
1354
1355 if (decode_result != SPV_SUCCESS)
1356 return decode_result;
1357 }
1358
1359 assert(inst->num_operands == parsed_operands_.size());
1360
1361 // Only valid while spirv_ and parsed_operands_ remain unchanged.
1362 inst->words = &spirv_[instruction_offset];
1363 inst->operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
1364 inst->num_words = static_cast<uint16_t>(spirv_.size() - instruction_offset);
1365 spirv_[instruction_offset] =
1366 spvOpcodeMake(inst->num_words, SpvOp(inst->opcode));
1367
1368 assert(inst->num_words == std::accumulate(
1369 parsed_operands_.begin(), parsed_operands_.end(), 1,
1370 [](int num_words, const spv_parsed_operand_t& operand) {
1371 return num_words += operand.num_words;
1372 }) && "num_words in instruction doesn't correspond to the sum of num_words"
1373 "in the operands");
1374
1375 RecordNumberType(*inst);
1376
1377 if (!ReadToByteBreakIfAgreed())
1378 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1379 << "Failed to read to byte break";
1380
1381 return SPV_SUCCESS;
1382 }
1383
SetNumericTypeInfoForType(spv_parsed_operand_t * parsed_operand,uint32_t type_id)1384 spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
1385 spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
1386 assert(type_id != 0);
1387 auto type_info_iter = type_id_to_number_type_info_.find(type_id);
1388 if (type_info_iter == type_id_to_number_type_info_.end()) {
1389 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1390 << "Type Id " << type_id << " is not a type";
1391 }
1392
1393 const NumberType& info = type_info_iter->second;
1394 if (info.type == SPV_NUMBER_NONE) {
1395 // This is a valid type, but for something other than a scalar number.
1396 return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1397 << "Type Id " << type_id << " is not a scalar numeric type";
1398 }
1399
1400 parsed_operand->number_kind = info.type;
1401 parsed_operand->number_bit_width = info.bit_width;
1402 // Round up the word count.
1403 parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
1404 return SPV_SUCCESS;
1405 }
1406
RecordNumberType(const spv_parsed_instruction_t & inst)1407 void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) {
1408 const SpvOp opcode = static_cast<SpvOp>(inst.opcode);
1409 if (spvOpcodeGeneratesType(opcode)) {
1410 NumberType info = {SPV_NUMBER_NONE, 0};
1411 if (SpvOpTypeInt == opcode) {
1412 info.bit_width = inst.words[inst.operands[1].offset];
1413 info.type = inst.words[inst.operands[2].offset] ?
1414 SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
1415 } else if (SpvOpTypeFloat == opcode) {
1416 info.bit_width = inst.words[inst.operands[1].offset];
1417 info.type = SPV_NUMBER_FLOATING;
1418 }
1419 // The *result* Id of a type generating instruction is the type Id.
1420 type_id_to_number_type_info_[inst.result_id] = info;
1421 }
1422 }
1423
EncodeHeader(void * user_data,spv_endianness_t endian,uint32_t magic,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t schema)1424 spv_result_t EncodeHeader(
1425 void* user_data, spv_endianness_t endian, uint32_t magic,
1426 uint32_t version, uint32_t generator, uint32_t id_bound,
1427 uint32_t schema) {
1428 MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
1429 return encoder->EncodeHeader(
1430 endian, magic, version, generator, id_bound, schema);
1431 }
1432
EncodeInstruction(void * user_data,const spv_parsed_instruction_t * inst)1433 spv_result_t EncodeInstruction(
1434 void* user_data, const spv_parsed_instruction_t* inst) {
1435 MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
1436 return encoder->EncodeInstruction(*inst);
1437 }
1438
1439 } // namespace
1440
spvSpirvToMarkv(spv_const_context context,const uint32_t * spirv_words,const size_t spirv_num_words,spv_const_markv_encoder_options options,spv_markv_binary * markv_binary,spv_text * comments,spv_diagnostic * diagnostic)1441 spv_result_t spvSpirvToMarkv(spv_const_context context,
1442 const uint32_t* spirv_words,
1443 const size_t spirv_num_words,
1444 spv_const_markv_encoder_options options,
1445 spv_markv_binary* markv_binary,
1446 spv_text* comments, spv_diagnostic* diagnostic) {
1447 spv_context_t hijack_context = *context;
1448 if (diagnostic) {
1449 *diagnostic = nullptr;
1450 libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
1451 }
1452
1453 spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words};
1454
1455 spv_endianness_t endian;
1456 spv_position_t position = {};
1457 if (spvBinaryEndianness(&spirv_binary, &endian)) {
1458 return libspirv::DiagnosticStream(position, hijack_context.consumer,
1459 SPV_ERROR_INVALID_BINARY)
1460 << "Invalid SPIR-V magic number.";
1461 }
1462
1463 spv_header_t header;
1464 if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
1465 return libspirv::DiagnosticStream(position, hijack_context.consumer,
1466 SPV_ERROR_INVALID_BINARY)
1467 << "Invalid SPIR-V header.";
1468 }
1469
1470 MarkvEncoder encoder(&hijack_context, options);
1471
1472 if (comments) {
1473 encoder.CreateCommentsLogger();
1474
1475 spv_text text = nullptr;
1476 if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words,
1477 SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr)
1478 != SPV_SUCCESS) {
1479 return libspirv::DiagnosticStream(position, hijack_context.consumer,
1480 SPV_ERROR_INVALID_BINARY)
1481 << "Failed to disassemble SPIR-V binary.";
1482 }
1483 assert(text);
1484 encoder.SetDisassembly(std::string(text->str, text->length));
1485 spvTextDestroy(text);
1486 }
1487
1488 if (spvBinaryParse(
1489 &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader,
1490 EncodeInstruction, diagnostic) != SPV_SUCCESS) {
1491 return libspirv::DiagnosticStream(position, hijack_context.consumer,
1492 SPV_ERROR_INVALID_BINARY)
1493 << "Unable to encode to MARK-V.";
1494 }
1495
1496 if (comments)
1497 *comments = CreateSpvText(encoder.GetComments());
1498
1499 *markv_binary = encoder.GetMarkvBinary();
1500 return SPV_SUCCESS;
1501 }
1502
spvMarkvToSpirv(spv_const_context context,const uint8_t * markv_data,size_t markv_size_bytes,spv_const_markv_decoder_options options,spv_binary * spirv_binary,spv_text *,spv_diagnostic * diagnostic)1503 spv_result_t spvMarkvToSpirv(spv_const_context context,
1504 const uint8_t* markv_data,
1505 size_t markv_size_bytes,
1506 spv_const_markv_decoder_options options,
1507 spv_binary* spirv_binary,
1508 spv_text* /* comments */, spv_diagnostic* diagnostic) {
1509 spv_position_t position = {};
1510 spv_context_t hijack_context = *context;
1511 if (diagnostic) {
1512 *diagnostic = nullptr;
1513 libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
1514 }
1515
1516 MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options);
1517
1518 std::vector<uint32_t> words;
1519
1520 if (decoder.DecodeModule(&words) != SPV_SUCCESS) {
1521 return libspirv::DiagnosticStream(position, hijack_context.consumer,
1522 SPV_ERROR_INVALID_BINARY)
1523 << "Unable to decode MARK-V.";
1524 }
1525
1526 assert(!words.empty());
1527
1528 *spirv_binary = new spv_binary_t();
1529 (*spirv_binary)->code = new uint32_t[words.size()];
1530 (*spirv_binary)->wordCount = words.size();
1531 std::memcpy((*spirv_binary)->code, words.data(), 4 * words.size());
1532
1533 return SPV_SUCCESS;
1534 }
1535
spvMarkvBinaryDestroy(spv_markv_binary binary)1536 void spvMarkvBinaryDestroy(spv_markv_binary binary) {
1537 if (!binary) return;
1538 delete[] binary->data;
1539 delete binary;
1540 }
1541
spvMarkvEncoderOptionsCreate()1542 spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {
1543 return new spv_markv_encoder_options_t;
1544 }
1545
spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options)1546 void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {
1547 delete options;
1548 }
1549
spvMarkvDecoderOptionsCreate()1550 spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {
1551 return new spv_markv_decoder_options_t;
1552 }
1553
spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options)1554 void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {
1555 delete options;
1556 }
1557