• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Contains
16 //   - SPIR-V to MARK-V encoder
17 //   - MARK-V to SPIR-V decoder
18 //
19 // MARK-V is a compression format for SPIR-V binaries. It strips away
20 // non-essential information (such as result ids which can be regenerated) and
21 // uses various bit reduction techniques to reduce the size of the binary.
22 //
23 // MarkvModel is a flatbuffers object containing a set of rules defining how
24 // compression/decompression is done (coding schemes, dictionaries).
25 
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "binary.h"
#include "diagnostic.h"
#include "enum_string_mapping.h"
#include "extensions.h"
#include "ext_inst.h"
#include "instruction.h"
#include "opcode.h"
#include "operand.h"
#include "spirv-tools/libspirv.h"
#include "spirv-tools/markv.h"
#include "spirv_endian.h"
#include "spirv_validator_options.h"
#include "util/bit_stream.h"
#include "util/parse_number.h"
#include "validate.h"
#include "val/instruction.h"
#include "val/validation_state.h"
54 
55 using libspirv::Instruction;
56 using libspirv::ValidationState_t;
57 using spvtools::ValidateInstructionAndUpdateValidationState;
58 using spvutils::BitReaderWord64;
59 using spvutils::BitWriterWord64;
60 
// Options for the MARK-V encoder. Currently carries no fields.
struct spv_markv_encoder_options_t {};
63 
// Options for the MARK-V decoder. Currently carries no fields.
struct spv_markv_decoder_options_t {};
66 
67 namespace {
68 
69 const uint32_t kSpirvMagicNumber = SpvMagicNumber;
70 const uint32_t kMarkvMagicNumber = 0x07230303;
71 
72 enum {
73   kMarkvFirstOpcode = 65536,
74   kMarkvOpNextInstructionEncodesResultId = 65536,
75 };
76 
77 const size_t kCommentNumWhitespaces = 2;
78 
// TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer
// containing MARK-V model for a specific dataset.
//
// Until then every accessor returns a hard-coded constant. The *_chunk_length
// values are passed as the chunk length of the variable width integer coders;
// the *_block_exponent values are the extra argument of the signed variants.
class MarkvModel {
 public:
  // --- Instruction level fields -------------------------------------------

  // Chunk length for opcodes.
  size_t opcode_chunk_length() const {
    return 7;
  }

  // Chunk length for the operand count of variable-length instructions.
  size_t num_operands_chunk_length() const {
    return 3;
  }

  // Chunk length for move-to-front id indices.
  size_t id_index_chunk_length() const {
    return 3;
  }

  // --- 16-bit typed literals ----------------------------------------------

  size_t u16_chunk_length() const {
    return 4;
  }

  size_t s16_chunk_length() const {
    return 4;
  }

  size_t s16_block_exponent() const {
    return 6;
  }

  // --- 32-bit typed literals ----------------------------------------------

  size_t u32_chunk_length() const {
    return 8;
  }

  size_t s32_chunk_length() const {
    return 8;
  }

  size_t s32_block_exponent() const {
    return 10;
  }

  // --- 64-bit typed literals ----------------------------------------------

  size_t u64_chunk_length() const {
    return 8;
  }

  size_t s64_chunk_length() const {
    return 8;
  }

  size_t s64_block_exponent() const {
    return 10;
  }
};
99 
GetDefaultModel()100 const MarkvModel* GetDefaultModel() {
101   static MarkvModel model;
102   return &model;
103 }
104 
105 // Returns chunk length used for variable length encoding of spirv operand
106 // words. Returns zero if operand type corresponds to potentially multiple
107 // words or a word which is not expected to profit from variable width encoding.
108 // Chunk length is selected based on the size of expected value.
109 // Most of these values will later be encoded with probability-based coding,
110 // but variable width integer coding is a good quick solution.
111 // TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
GetOperandVariableWidthChunkLength(spv_operand_type_t type)112 size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
113   switch (type) {
114     case SPV_OPERAND_TYPE_TYPE_ID:
115       return 4;
116     case SPV_OPERAND_TYPE_RESULT_ID:
117     case SPV_OPERAND_TYPE_ID:
118     case SPV_OPERAND_TYPE_SCOPE_ID:
119     case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
120       return 8;
121     case SPV_OPERAND_TYPE_LITERAL_INTEGER:
122     case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
123       return 6;
124     case SPV_OPERAND_TYPE_CAPABILITY:
125       return 6;
126     case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
127     case SPV_OPERAND_TYPE_EXECUTION_MODEL:
128       return 3;
129     case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
130     case SPV_OPERAND_TYPE_MEMORY_MODEL:
131       return 2;
132     case SPV_OPERAND_TYPE_EXECUTION_MODE:
133       return 6;
134     case SPV_OPERAND_TYPE_STORAGE_CLASS:
135       return 4;
136     case SPV_OPERAND_TYPE_DIMENSIONALITY:
137     case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
138       return 3;
139     case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
140       return 2;
141     case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
142       return 6;
143     case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
144     case SPV_OPERAND_TYPE_LINKAGE_TYPE:
145     case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
146     case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
147       return 2;
148     case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
149       return 3;
150     case SPV_OPERAND_TYPE_DECORATION:
151     case SPV_OPERAND_TYPE_BUILT_IN:
152       return 6;
153     case SPV_OPERAND_TYPE_GROUP_OPERATION:
154     case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
155     case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
156       return 2;
157     case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
158     case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
159     case SPV_OPERAND_TYPE_LOOP_CONTROL:
160     case SPV_OPERAND_TYPE_IMAGE:
161     case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
162     case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
163     case SPV_OPERAND_TYPE_SELECTION_CONTROL:
164       return 4;
165     case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
166       return 6;
167     default:
168       return 0;
169   }
170   return 0;
171 }
172 
173 // Returns true if the opcode has a fixed number of operands. May return a
174 // false negative.
OpcodeHasFixedNumberOfOperands(SpvOp opcode)175 bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
176   switch (opcode) {
177     // TODO(atgoo@github.com) This is not a complete list.
178     case SpvOpNop:
179     case SpvOpName:
180     case SpvOpUndef:
181     case SpvOpSizeOf:
182     case SpvOpLine:
183     case SpvOpNoLine:
184     case SpvOpDecorationGroup:
185     case SpvOpExtension:
186     case SpvOpExtInstImport:
187     case SpvOpMemoryModel:
188     case SpvOpCapability:
189     case SpvOpTypeVoid:
190     case SpvOpTypeBool:
191     case SpvOpTypeInt:
192     case SpvOpTypeFloat:
193     case SpvOpTypeVector:
194     case SpvOpTypeMatrix:
195     case SpvOpTypeSampler:
196     case SpvOpTypeSampledImage:
197     case SpvOpTypeArray:
198     case SpvOpTypePointer:
199     case SpvOpConstantTrue:
200     case SpvOpConstantFalse:
201     case SpvOpLabel:
202     case SpvOpBranch:
203     case SpvOpFunction:
204     case SpvOpFunctionParameter:
205     case SpvOpFunctionEnd:
206     case SpvOpBitcast:
207     case SpvOpCopyObject:
208     case SpvOpTranspose:
209     case SpvOpSNegate:
210     case SpvOpFNegate:
211     case SpvOpIAdd:
212     case SpvOpFAdd:
213     case SpvOpISub:
214     case SpvOpFSub:
215     case SpvOpIMul:
216     case SpvOpFMul:
217     case SpvOpUDiv:
218     case SpvOpSDiv:
219     case SpvOpFDiv:
220     case SpvOpUMod:
221     case SpvOpSRem:
222     case SpvOpSMod:
223     case SpvOpFRem:
224     case SpvOpFMod:
225     case SpvOpVectorTimesScalar:
226     case SpvOpMatrixTimesScalar:
227     case SpvOpVectorTimesMatrix:
228     case SpvOpMatrixTimesVector:
229     case SpvOpMatrixTimesMatrix:
230     case SpvOpOuterProduct:
231     case SpvOpDot:
232       return true;
233     default:
234       break;
235   }
236   return false;
237 }
238 
// Returns how many bits separate |bit_pos| from the next byte boundary.
// The result is zero when |bit_pos| already lies on a byte boundary.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_byte = bit_pos % 8;
  if (bits_into_byte == 0)
    return 0;
  return 8 - bits_into_byte;
}
242 
ShouldByteBreak(size_t bit_pos)243 bool ShouldByteBreak(size_t bit_pos) {
244   const size_t num_bits_to_next_byte = GetNumBitsToNextByte(bit_pos);
245   return num_bits_to_next_byte > 0; // && num_bits_to_next_byte <= 2;
246 }
247 
// Defines and returns current MARK-V version. The major version occupies the
// bits starting at bit 16, the minor version the low bits.
uint32_t GetMarkvVersion() {
  const uint32_t kMajor = 1;
  const uint32_t kMinor = 0;
  const uint32_t packed_version = (kMajor << 16) | kMinor;
  return packed_version;
}
254 
// Accumulates human readable comments in an internal string stream.
// Consecutive bit sequences are joined with '-'; any other kind of append
// (text, whitespace, new line) resets that joining, which yields outputs like:
// 1100-0 1110-1-1100-1-1111-0 110-0.
class CommentLogger {
 public:
  // Appends |str| verbatim.
  void AppendText(const std::string& str) {
    Emit(str, false);
  }

  // Appends |str| followed by a line break.
  void AppendTextNewLine(const std::string& str) {
    Emit(str, false);
    Emit("\n", false);
  }

  // Appends the bit sequence |str|, preceded by a "-" delimiter if the
  // previous append was a bit sequence as well.
  void AppendBitSequence(const std::string& str) {
    if (use_delimiter_) {
      Emit("-", true);
    }
    Emit(str, true);
  }

  // Appends |num| space characters.
  void AppendWhitespaces(size_t num) {
    Emit(std::string(num, ' '), false);
  }

  // Appends a line break.
  void NewLine() {
    Emit("\n", false);
  }

  // Returns everything logged so far.
  std::string GetText() const {
    return ss_.str();
  }

 private:
  // Writes |str| to the stream and records whether the next bit sequence
  // must be preceded by a delimiter.
  void Emit(const std::string& str, bool delimiter_before_next_bits) {
    ss_ << str;
    // std::cerr << str;
    use_delimiter_ = delimiter_before_next_bits;
  }

  std::stringstream ss_;

  // If true a delimiter will be appended before the next bit sequence.
  bool use_delimiter_ = false;
};
301 
302 // Creates spv_text object containing text from |str|.
303 // The returned value is owned by the caller and needs to be destroyed with
304 // spvTextDestroy.
CreateSpvText(const std::string & str)305 spv_text CreateSpvText(const std::string& str) {
306   spv_text out = new spv_text_t();
307   assert(out);
308   char* cstr = new char[str.length() + 1];
309   assert(cstr);
310   std::strncpy(cstr, str.c_str(), str.length());
311   cstr[str.length()] = '\0';
312   out->str = cstr;
313   out->length = str.length();
314   return out;
315 }
316 
317 // Base class for MARK-V encoder and decoder. Contains common functionality
318 // such as:
319 // - Validator connection and validation state.
320 // - SPIR-V grammar and helper functions.
321 class MarkvCodecBase {
322  public:
~MarkvCodecBase()323   virtual ~MarkvCodecBase() {
324     spvValidatorOptionsDestroy(validator_options_);
325   }
326 
327   MarkvCodecBase() = delete;
328 
SetModel(const MarkvModel * model)329   void SetModel(const MarkvModel* model) {
330     model_ = model;
331   }
332 
333  protected:
334   struct MarkvHeader {
MarkvHeader__anon5783fbf10111::MarkvCodecBase::MarkvHeader335     MarkvHeader() {
336       magic_number = kMarkvMagicNumber;
337       markv_version = GetMarkvVersion();
338       markv_model = 0;
339       markv_length_in_bits = 0;
340       spirv_version = 0;
341       spirv_generator = 0;
342     }
343 
344     uint32_t magic_number;
345     uint32_t markv_version;
346     // Magic number to identify or verify MarkvModel used for encoding.
347     uint32_t markv_model;
348     uint32_t markv_length_in_bits;
349     uint32_t spirv_version;
350     uint32_t spirv_generator;
351   };
352 
MarkvCodecBase(spv_const_context context,spv_validator_options validator_options)353   explicit MarkvCodecBase(spv_const_context context,
354                           spv_validator_options validator_options)
355       : validator_options_(validator_options),
356         vstate_(context, validator_options_), grammar_(context),
357         model_(GetDefaultModel()) {}
358 
359   // Validates a single instruction and updates validation state of the module.
UpdateValidationState(const spv_parsed_instruction_t & inst)360   spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
361     return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
362   }
363 
364   // Returns the current instruction (the one last processed by the validator).
GetCurrentInstruction() const365   const Instruction& GetCurrentInstruction() const {
366     return vstate_.ordered_instructions().back();
367   }
368 
369   spv_validator_options validator_options_;
370   ValidationState_t vstate_;
371   const libspirv::AssemblyGrammar grammar_;
372   MarkvHeader header_;
373   const MarkvModel* model_;
374 
375   // Move-to-front list of all ids.
376   // TODO(atgoo@github.com) Consider a better move-to-front implementation.
377   std::list<uint32_t> move_to_front_ids_;
378 };
379 
380 // SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
381 // EncodeInstruction which can be used as callback by spvBinaryParse.
382 // Encoded binary is written to an internally maintained bitstream.
383 // After the last instruction is encoded, the resulting MARK-V binary can be
384 // acquired by calling GetMarkvBinary().
385 // The encoder uses SPIR-V validator to keep internal state, therefore
386 // SPIR-V binary needs to be able to pass validator checks.
387 // CreateCommentsLogger() can be used to enable the encoder to write comments
388 // on how encoding was done, which can later be accessed with GetComments().
389 class MarkvEncoder : public MarkvCodecBase {
390  public:
MarkvEncoder(spv_const_context context,spv_const_markv_encoder_options options)391   MarkvEncoder(spv_const_context context,
392                spv_const_markv_encoder_options options)
393       : MarkvCodecBase(context, GetValidatorOptions(options)),
394         options_(options) {
395     (void) options_;
396   }
397 
398   // Writes data from SPIR-V header to MARK-V header.
EncodeHeader(spv_endianness_t,uint32_t,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t)399   spv_result_t EncodeHeader(
400       spv_endianness_t /* endian */, uint32_t /* magic */,
401       uint32_t version, uint32_t generator, uint32_t id_bound,
402       uint32_t /* schema */) {
403     vstate_.setIdBound(id_bound);
404     header_.spirv_version = version;
405     header_.spirv_generator = generator;
406     return SPV_SUCCESS;
407   }
408 
409   // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
410   // Operation can fail if the instruction fails to pass the validator or if
411   // the encoder stubmles on something unexpected.
412   spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
413 
414   // Concatenates MARK-V header and the bit stream with encoded instructions
415   // into a single buffer and returns it as spv_markv_binary. The returned
416   // value is owned by the caller and needs to be destroyed with
417   // spvMarkvBinaryDestroy().
GetMarkvBinary()418   spv_markv_binary GetMarkvBinary() {
419     header_.markv_length_in_bits =
420         static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
421     const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
422 
423     spv_markv_binary markv_binary = new spv_markv_binary_t();
424     markv_binary->data = new uint8_t[num_bytes];
425     markv_binary->length = num_bytes;
426     assert(writer_.GetData());
427     std::memcpy(markv_binary->data, &header_, sizeof(header_));
428     std::memcpy(markv_binary->data + sizeof(header_),
429            writer_.GetData(), writer_.GetDataSizeBytes());
430     return markv_binary;
431   }
432 
433   // Creates an internal logger which writes comments on the encoding process.
434   // Output can later be accessed with GetComments().
CreateCommentsLogger()435   void CreateCommentsLogger() {
436     logger_.reset(new CommentLogger());
437     writer_.SetCallback([this](const std::string& str){
438       logger_->AppendBitSequence(str);
439     });
440   }
441 
442   // Optionally adds disassembly to the comments.
443   // Disassembly should contain all instructions in the module separated by
444   // \n, and no header.
SetDisassembly(std::string && disassembly)445   void SetDisassembly(std::string&& disassembly) {
446     disassembly_.reset(new std::stringstream(std::move(disassembly)));
447   }
448 
449   // Extracts the next instruction line from the disassembly and logs it.
LogDisassemblyInstruction()450   void LogDisassemblyInstruction() {
451     if (logger_ && disassembly_) {
452       std::string line;
453       std::getline(*disassembly_, line, '\n');
454       logger_->AppendTextNewLine(line);
455     }
456   }
457 
458   // Extracts the text from the comment logger.
GetComments() const459   std::string GetComments() const {
460     if (!logger_)
461       return "";
462     return logger_->GetText();
463   }
464 
465  private:
466   // Creates and returns validator options. Return value owned by the caller.
GetValidatorOptions(spv_const_markv_encoder_options)467   static spv_validator_options GetValidatorOptions(
468       spv_const_markv_encoder_options) {
469     return spvValidatorOptionsCreate();
470   }
471 
472   // Writes a single word to bit stream. |type| determines if the word is
473   // encoded and how.
EncodeOperandWord(spv_operand_type_t type,uint32_t word)474   void EncodeOperandWord(spv_operand_type_t type, uint32_t word) {
475     const size_t chunk_length =
476         GetOperandVariableWidthChunkLength(type);
477     if (chunk_length) {
478       writer_.WriteVariableWidthU32(word, chunk_length);
479     } else {
480       writer_.WriteUnencoded(word);
481     }
482   }
483 
484   // Returns id index and updates move-to-front.
485   // Index is uint16 as SPIR-V module is guaranteed to have no more than 65535
486   // instructions.
GetIdIndex(uint32_t id)487   uint16_t GetIdIndex(uint32_t id) {
488     if (all_known_ids_.count(id)) {
489       uint16_t index = 0;
490       for (auto it = move_to_front_ids_.begin();
491            it != move_to_front_ids_.end(); ++it) {
492         if (*it == id) {
493           if (index != 0) {
494             move_to_front_ids_.erase(it);
495             move_to_front_ids_.push_front(id);
496           }
497           return index;
498         }
499         ++index;
500       }
501       assert(0 && "Id not found in move_to_front_ids_");
502       return 0;
503     } else {
504       all_known_ids_.insert(id);
505       move_to_front_ids_.push_front(id);
506       return static_cast<uint16_t>(move_to_front_ids_.size() - 1);
507     }
508   }
509 
AddByteBreakIfAgreed()510   void AddByteBreakIfAgreed() {
511     if (!ShouldByteBreak(writer_.GetNumBits()))
512       return;
513 
514     if (logger_) {
515       logger_->AppendWhitespaces(kCommentNumWhitespaces);
516       logger_->AppendText("ByteBreak:");
517     }
518 
519     writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits()));
520   }
521 
522   // Encodes a literal number operand and writes it to the bit stream.
523   void EncodeLiteralNumber(const Instruction& instruction,
524                            const spv_parsed_operand_t& operand);
525 
526   spv_const_markv_encoder_options options_;
527 
528   // Bit stream where encoded instructions are written.
529   BitWriterWord64 writer_;
530 
531   // If not nullptr, encoder will write comments.
532   std::unique_ptr<CommentLogger> logger_;
533 
534   // If not nullptr, disassembled instruction lines will be written to comments.
535   // Format: \n separated instruction lines, no header.
536   std::unique_ptr<std::stringstream> disassembly_;
537 
538   // All ids which were previosly encountered in the module.
539   std::unordered_set<uint32_t> all_known_ids_;
540 };
541 
542 // Decodes MARK-V buffers written by MarkvEncoder.
543 class MarkvDecoder : public MarkvCodecBase {
544  public:
MarkvDecoder(spv_const_context context,const uint8_t * markv_data,size_t markv_size_bytes,spv_const_markv_decoder_options options)545   MarkvDecoder(spv_const_context context,
546                const uint8_t* markv_data,
547                size_t markv_size_bytes,
548                spv_const_markv_decoder_options options)
549       : MarkvCodecBase(context, GetValidatorOptions(options)),
550         options_(options), reader_(markv_data, markv_size_bytes) {
551     (void) options_;
552     vstate_.setIdBound(1);
553     parsed_operands_.reserve(25);
554   }
555 
556   // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
557   // Can be called only once. Fails if data of wrong format or ends prematurely,
558   // of if validation fails.
559   spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
560 
561  private:
562   // Describes the format of a typed literal number.
563   struct NumberType {
564     spv_number_kind_t type;
565     uint32_t bit_width;
566   };
567 
568   // Creates and returns validator options. Return value owned by the caller.
GetValidatorOptions(spv_const_markv_decoder_options)569   static spv_validator_options GetValidatorOptions(
570       spv_const_markv_decoder_options) {
571     return spvValidatorOptionsCreate();
572   }
573 
574   // Reads a single word from bit stream. |type| determines if the word needs
575   // to be decoded and how. Returns false if read fails.
DecodeOperandWord(spv_operand_type_t type,uint32_t * word)576   bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) {
577     const size_t chunk_length = GetOperandVariableWidthChunkLength(type);
578     if (chunk_length) {
579       return reader_.ReadVariableWidthU32(word, chunk_length);
580     } else {
581       return reader_.ReadUnencoded(word);
582     }
583   }
584 
585   // Fetches the id from the move-to-front list and moves it to front.
GetIdAndMoveToFront(uint16_t index)586   uint32_t GetIdAndMoveToFront(uint16_t index) {
587     if (index >= move_to_front_ids_.size()) {
588       // Issue new id.
589       const uint32_t id = vstate_.getIdBound();
590       move_to_front_ids_.push_front(id);
591       vstate_.setIdBound(id + 1);
592       return id;
593     } else {
594       if (index == 0)
595         return move_to_front_ids_.front();
596 
597       // Iterate to index.
598       auto it = move_to_front_ids_.begin();
599       for (size_t i = 0; i < index; ++i)
600         ++it;
601       const uint32_t id = *it;
602       move_to_front_ids_.erase(it);
603       move_to_front_ids_.push_front(id);
604       return id;
605     }
606   }
607 
608   // Decodes id index and fetches the id from move-to-front list.
DecodeId(uint32_t * id)609   bool DecodeId(uint32_t* id) {
610     uint16_t index = 0;
611     if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length()))
612        return false;
613 
614     *id = GetIdAndMoveToFront(index);
615     return true;
616   }
617 
ReadToByteBreakIfAgreed()618   bool ReadToByteBreakIfAgreed() {
619     if (!ShouldByteBreak(reader_.GetNumReadBits()))
620       return true;
621 
622     uint64_t bits = 0;
623     if (!reader_.ReadBits(&bits,
624                           GetNumBitsToNextByte(reader_.GetNumReadBits())))
625       return false;
626 
627     if (bits != 0)
628       return false;
629 
630     return true;
631   }
632 
633   // Reads a literal number as it is described in |operand| from the bit stream,
634   // decodes and writes it to spirv_.
635   spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
636 
637   // Reads instruction from bit stream, decodes and validates it.
638   // Decoded instruction is valid until the next call of DecodeInstruction().
639   spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst);
640 
641   // Read operand from the stream decodes and validates it.
642   spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset,
643                              spv_parsed_instruction_t* inst,
644                              const spv_operand_type_t type,
645                              spv_operand_pattern_t* expected_operands,
646                              bool read_result_id);
647 
648   // Records the numeric type for an operand according to the type information
649   // associated with the given non-zero type Id.  This can fail if the type Id
650   // is not a type Id, or if the type Id does not reference a scalar numeric
651   // type.  On success, return SPV_SUCCESS and populates the num_words,
652   // number_kind, and number_bit_width fields of parsed_operand.
653   spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
654                                          uint32_t type_id);
655 
656   // Records the number type for the given instruction, if that
657   // instruction generates a type.  For types that aren't scalar numbers,
658   // record something with number kind SPV_NUMBER_NONE.
659   void RecordNumberType(const spv_parsed_instruction_t& inst);
660 
661   spv_const_markv_decoder_options options_;
662 
663   // Temporary sink where decoded SPIR-V words are written. Once it contains the
664   // entire module, the container is moved and returned.
665   std::vector<uint32_t> spirv_;
666 
667   // Bit stream containing encoded data.
668   BitReaderWord64 reader_;
669 
670   // Temporary storage for operands of the currently parsed instruction.
671   // Valid until next DecodeInstruction call.
672   std::vector<spv_parsed_operand_t> parsed_operands_;
673 
674   // Maps a result ID to its type ID.  By convention:
675   //  - a result ID that is a type definition maps to itself.
676   //  - a result ID without a type maps to 0.  (E.g. for OpLabel)
677   std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
678   // Maps a type ID to its number type description.
679   std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
680   // Maps an ExtInstImport id to the extended instruction type.
681   std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
682 };
683 
EncodeLiteralNumber(const Instruction & instruction,const spv_parsed_operand_t & operand)684 void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction,
685                                        const spv_parsed_operand_t& operand) {
686   if (operand.number_bit_width == 32) {
687     const uint32_t word = instruction.word(operand.offset);
688     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
689       writer_.WriteVariableWidthU32(word, model_->u32_chunk_length());
690     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
691       int32_t val = 0;
692       std::memcpy(&val, &word, 4);
693       writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(),
694                                     model_->s32_block_exponent());
695     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
696       writer_.WriteUnencoded(word);
697     } else {
698       assert(0);
699     }
700   } else if (operand.number_bit_width == 16) {
701     const uint16_t word =
702         static_cast<uint16_t>(instruction.word(operand.offset));
703     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
704       writer_.WriteVariableWidthU16(word, model_->u16_chunk_length());
705     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
706       int16_t val = 0;
707       std::memcpy(&val, &word, 2);
708       writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(),
709                                     model_->s16_block_exponent());
710     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
711       // TODO(atgoo@github.com) Write only 16 bits.
712       writer_.WriteUnencoded(word);
713     } else {
714       assert(0);
715     }
716   } else {
717     assert(operand.number_bit_width == 64);
718     const uint64_t word =
719         uint64_t(instruction.word(operand.offset)) |
720         (uint64_t(instruction.word(operand.offset + 1)) << 32);
721     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
722       writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
723     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
724       int64_t val = 0;
725       std::memcpy(&val, &word, 8);
726       writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
727                                     model_->s64_block_exponent());
728     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
729       writer_.WriteUnencoded(word);
730     } else {
731       assert(0);
732     }
733   }
734 }
735 
EncodeInstruction(const spv_parsed_instruction_t & inst)736 spv_result_t MarkvEncoder::EncodeInstruction(
737     const spv_parsed_instruction_t& inst) {
738   const spv_result_t validation_result = UpdateValidationState(inst);
739   if (validation_result != SPV_SUCCESS)
740     return validation_result;
741 
742   bool result_id_was_forward_declared = false;
743   if (all_known_ids_.count(inst.result_id)) {
744     // Result id of the instruction was forward declared.
745     // Write a service opcode to signal this to the decoder.
746     writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId,
747                                   model_->opcode_chunk_length());
748     result_id_was_forward_declared = true;
749   }
750 
751   const Instruction& instruction = GetCurrentInstruction();
752   const auto& operands = instruction.operands();
753 
754   LogDisassemblyInstruction();
755 
756   // Write opcode.
757   writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length());
758 
759   if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) {
760     // If the opcode has a variable number of operands, encode the number of
761     // operands with the instruction.
762 
763     if (logger_)
764       logger_->AppendWhitespaces(kCommentNumWhitespaces);
765 
766     writer_.WriteVariableWidthU16(inst.num_operands,
767                                   model_->num_operands_chunk_length());
768   }
769 
770   // Write operands.
771   for (const auto& operand : operands) {
772     if (operand.type == SPV_OPERAND_TYPE_RESULT_ID &&
773         !result_id_was_forward_declared) {
774       // Register the id, but don't encode it.
775       GetIdIndex(instruction.word(operand.offset));
776       continue;
777     }
778 
779     if (logger_)
780       logger_->AppendWhitespaces(kCommentNumWhitespaces);
781 
782     if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) {
783       EncodeLiteralNumber(instruction, operand);
784     } else if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
785       const char* src =
786           reinterpret_cast<const char*>(&instruction.words()[operand.offset]);
787       const size_t length = spv_strnlen_s(src, operand.num_words * 4);
788       if (length == operand.num_words * 4)
789         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
790             << "Failed to find terminal character of literal string";
791       for (size_t i = 0; i < length + 1; ++i)
792         writer_.WriteUnencoded(src[i]);
793     } else if (spvIsIdType(operand.type)) {
794       const uint16_t id_index = GetIdIndex(instruction.word(operand.offset));
795       writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length());
796     } else {
797       for (int i = 0; i < operand.num_words; ++i) {
798         const uint32_t word = instruction.word(operand.offset + i);
799         EncodeOperandWord(operand.type, word);
800       }
801     }
802   }
803 
804   AddByteBreakIfAgreed();
805 
806   if (logger_) {
807     logger_->NewLine();
808     logger_->NewLine();
809   }
810 
811   return SPV_SUCCESS;
812 }
813 
DecodeLiteralNumber(const spv_parsed_operand_t & operand)814 spv_result_t MarkvDecoder::DecodeLiteralNumber(
815     const spv_parsed_operand_t& operand) {
816   if (operand.number_bit_width == 32) {
817     uint32_t word = 0;
818     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
819       if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length()))
820         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
821             << "Failed to read literal U32";
822     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
823       int32_t val = 0;
824       if (!reader_.ReadVariableWidthS32(&val, model_->s32_chunk_length(),
825                                         model_->s32_block_exponent()))
826         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
827             << "Failed to read literal S32";
828       std::memcpy(&word, &val, 4);
829     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
830       if (!reader_.ReadUnencoded(&word))
831         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
832             << "Failed to read literal F32";
833     } else {
834       assert(0);
835     }
836     spirv_.push_back(word);
837   } else if (operand.number_bit_width == 16) {
838     uint32_t word = 0;
839     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
840       uint16_t val = 0;
841       if (!reader_.ReadVariableWidthU16(&val, model_->u16_chunk_length()))
842         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
843             << "Failed to read literal U16";
844       word = val;
845     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
846       int16_t val = 0;
847       if (!reader_.ReadVariableWidthS16(&val, model_->s16_chunk_length(),
848                                         model_->s16_block_exponent()))
849         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
850             << "Failed to read literal S16";
851       // Int16 is stored as int32 in SPIR-V, not as bits.
852       int32_t val32 = val;
853       std::memcpy(&word, &val32, 4);
854     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
855       uint16_t word16 = 0;
856       if (!reader_.ReadUnencoded(&word16))
857         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
858             << "Failed to read literal F16";
859       word = word16;
860     } else {
861       assert(0);
862     }
863     spirv_.push_back(word);
864   } else {
865     assert(operand.number_bit_width == 64);
866     uint64_t word = 0;
867     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
868       if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
869         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
870             << "Failed to read literal U64";
871     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
872       int64_t val = 0;
873       if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
874                                         model_->s64_block_exponent()))
875         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
876             << "Failed to read literal S64";
877       std::memcpy(&word, &val, 8);
878     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
879       if (!reader_.ReadUnencoded(&word))
880         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
881             << "Failed to read literal F64";
882     } else {
883       assert(0);
884     }
885     spirv_.push_back(static_cast<uint32_t>(word));
886     spirv_.push_back(static_cast<uint32_t>(word >> 32));
887   }
888   return SPV_SUCCESS;
889 }
890 
DecodeModule(std::vector<uint32_t> * spirv_binary)891 spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
892   const bool header_read_success =
893       reader_.ReadUnencoded(&header_.magic_number) &&
894       reader_.ReadUnencoded(&header_.markv_version) &&
895       reader_.ReadUnencoded(&header_.markv_model) &&
896       reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
897       reader_.ReadUnencoded(&header_.spirv_version) &&
898       reader_.ReadUnencoded(&header_.spirv_generator);
899 
900   if (!header_read_success)
901     return vstate_.diag(SPV_ERROR_INVALID_BINARY)
902         << "Unable to read MARK-V header";
903 
904   assert(header_.magic_number == kMarkvMagicNumber);
905   assert(header_.markv_length_in_bits > 0);
906 
907   if (header_.magic_number != kMarkvMagicNumber)
908     return vstate_.diag(SPV_ERROR_INVALID_BINARY)
909         << "MARK-V binary has incorrect magic number";
910 
911   // TODO(atgoo@github.com): Print version strings.
912   if (header_.markv_version != GetMarkvVersion())
913     return vstate_.diag(SPV_ERROR_INVALID_BINARY)
914         << "MARK-V binary and the codec have different versions";
915 
916   spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
917   spirv_.resize(5, 0);
918   spirv_[0] = kSpirvMagicNumber;
919   spirv_[1] = header_.spirv_version;
920   spirv_[2] = header_.spirv_generator;
921 
922   while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
923     spv_parsed_instruction_t inst = {};
924     const spv_result_t decode_result = DecodeInstruction(&inst);
925     if (decode_result != SPV_SUCCESS)
926       return decode_result;
927 
928     const spv_result_t validation_result = UpdateValidationState(inst);
929     if (validation_result != SPV_SUCCESS)
930       return validation_result;
931   }
932 
933 
934   if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
935       !reader_.OnlyZeroesLeft()) {
936     return vstate_.diag(SPV_ERROR_INVALID_BINARY)
937         << "MARK-V binary has wrong stated bit length "
938         << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
939   }
940 
941   // Decoding of the module is finished, validation state should have correct
942   // id bound.
943   spirv_[3] = vstate_.getIdBound();
944 
945   *spirv_binary = std::move(spirv_);
946   return SPV_SUCCESS;
947 }
948 
949 // TODO(atgoo@github.com): The implementation borrows heavily from
950 // Parser::parseOperand.
951 // Consider coupling them together in some way once MARK-V codec is more mature.
952 // For now it's better to keep the code independent for experimentation
953 // purposes.
DecodeOperand(size_t instruction_offset,size_t operand_offset,spv_parsed_instruction_t * inst,const spv_operand_type_t type,spv_operand_pattern_t * expected_operands,bool read_result_id)954 spv_result_t MarkvDecoder::DecodeOperand(
955     size_t instruction_offset, size_t operand_offset,
956     spv_parsed_instruction_t* inst, const spv_operand_type_t type,
957     spv_operand_pattern_t* expected_operands,
958     bool read_result_id) {
959   const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
960 
961   spv_parsed_operand_t parsed_operand;
962   memset(&parsed_operand, 0, sizeof(parsed_operand));
963 
964   assert((operand_offset >> 16) == 0);
965   parsed_operand.offset = static_cast<uint16_t>(operand_offset);
966   parsed_operand.type = type;
967 
968   // Set default values, may be updated later.
969   parsed_operand.number_kind = SPV_NUMBER_NONE;
970   parsed_operand.number_bit_width = 0;
971 
972   const size_t first_word_index = spirv_.size();
973 
974   switch (type) {
975     case SPV_OPERAND_TYPE_TYPE_ID: {
976       if (!DecodeId(&inst->type_id)) {
977         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
978             << "Failed to read type_id";
979       }
980 
981       if (inst->type_id == 0)
982         return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0";
983 
984       spirv_.push_back(inst->type_id);
985       vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1));
986       break;
987     }
988 
989     case SPV_OPERAND_TYPE_RESULT_ID: {
990       if (read_result_id) {
991         if (!DecodeId(&inst->result_id))
992           return vstate_.diag(SPV_ERROR_INVALID_BINARY)
993               << "Failed to read result_id";
994       } else {
995         inst->result_id = vstate_.getIdBound();
996         vstate_.setIdBound(inst->result_id + 1);
997         move_to_front_ids_.push_front(inst->result_id);
998       }
999 
1000       spirv_.push_back(inst->result_id);
1001 
1002       // Save the result ID to type ID mapping.
1003       // In the grammar, type ID always appears before result ID.
1004       // A regular value maps to its type. Some instructions (e.g. OpLabel)
1005       // have no type Id, and will map to 0. The result Id for a
1006       // type-generating instruction (e.g. OpTypeInt) maps to itself.
1007       auto insertion_result = id_to_type_id_.emplace(
1008           inst->result_id,
1009           spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id);
1010       if(!insertion_result.second) {
1011         return vstate_.diag(SPV_ERROR_INVALID_ID)
1012             << "Unexpected behavior: id->type_id pair was already registered";
1013       }
1014       break;
1015     }
1016 
1017     case SPV_OPERAND_TYPE_ID:
1018     case SPV_OPERAND_TYPE_OPTIONAL_ID:
1019     case SPV_OPERAND_TYPE_SCOPE_ID:
1020     case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
1021       uint32_t id = 0;
1022       if (!DecodeId(&id))
1023         return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id";
1024 
1025       if (id == 0)
1026         return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
1027 
1028       spirv_.push_back(id);
1029       vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1));
1030 
1031       if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
1032 
1033         parsed_operand.type = SPV_OPERAND_TYPE_ID;
1034 
1035         if (opcode == SpvOpExtInst && parsed_operand.offset == 3) {
1036           // The current word is the extended instruction set id.
1037           // Set the extended instruction set type for the current instruction.
1038           auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
1039           if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
1040             return vstate_.diag(SPV_ERROR_INVALID_ID)
1041                 << "OpExtInst set id " << id
1042                 << " does not reference an OpExtInstImport result Id";
1043           }
1044           inst->ext_inst_type = ext_inst_type_iter->second;
1045         }
1046       }
1047       break;
1048     }
1049 
1050     case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
1051       uint32_t word = 0;
1052       if (!DecodeOperandWord(type, &word))
1053         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1054             << "Failed to read enum";
1055 
1056       spirv_.push_back(word);
1057 
1058       assert(SpvOpExtInst == opcode);
1059       assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE);
1060       spv_ext_inst_desc ext_inst;
1061       if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst))
1062         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1063             << "Invalid extended instruction number: " << word;
1064       spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
1065       break;
1066     }
1067 
1068     case SPV_OPERAND_TYPE_LITERAL_INTEGER:
1069     case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
1070       // These are regular single-word literal integer operands.
1071       // Post-parsing validation should check the range of the parsed value.
1072       parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
1073       // It turns out they are always unsigned integers!
1074       parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT;
1075       parsed_operand.number_bit_width = 32;
1076 
1077       uint32_t word = 0;
1078       if (!DecodeOperandWord(type, &word))
1079         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1080             << "Failed to read literal integer";
1081 
1082       spirv_.push_back(word);
1083       break;
1084     }
1085 
1086     case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
1087     case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
1088       parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
1089       if (opcode == SpvOpSwitch) {
1090         // The literal operands have the same type as the value
1091         // referenced by the selector Id.
1092         const uint32_t selector_id = spirv_.at(instruction_offset + 1);
1093         const auto type_id_iter = id_to_type_id_.find(selector_id);
1094         if (type_id_iter == id_to_type_id_.end() ||
1095             type_id_iter->second == 0) {
1096           return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1097               << "Invalid OpSwitch: selector id " << selector_id
1098               << " has no type";
1099         }
1100         uint32_t type_id = type_id_iter->second;
1101 
1102         if (selector_id == type_id) {
1103           // Recall that by convention, a result ID that is a type definition
1104           // maps to itself.
1105           return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1106               << "Invalid OpSwitch: selector id " << selector_id
1107               << " is a type, not a value";
1108         }
1109         if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id))
1110           return error;
1111         if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT &&
1112             parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) {
1113           return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1114               << "Invalid OpSwitch: selector id " << selector_id
1115               << " is not a scalar integer";
1116         }
1117       } else {
1118         assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
1119         // The literal number type is determined by the type Id for the
1120         // constant.
1121         assert(inst->type_id);
1122         if (auto error =
1123             SetNumericTypeInfoForType(&parsed_operand, inst->type_id))
1124           return error;
1125       }
1126 
1127       if (auto error = DecodeLiteralNumber(parsed_operand))
1128         return error;
1129 
1130       break;
1131 
1132     case SPV_OPERAND_TYPE_LITERAL_STRING:
1133     case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
1134       parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING;
1135       std::vector<char> str;
1136       // The loop is expected to terminate once we encounter '\0' or exhaust
1137       // the bit stream.
1138       while (true) {
1139         char ch = 0;
1140         if (!reader_.ReadUnencoded(&ch))
1141           return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1142               << "Failed to read literal string";
1143 
1144         str.push_back(ch);
1145 
1146         if (ch == '\0')
1147           break;
1148       }
1149 
1150       while (str.size() % 4 != 0)
1151         str.push_back('\0');
1152 
1153       spirv_.resize(spirv_.size() + str.size() / 4);
1154       std::memcpy(&spirv_[first_word_index], str.data(), str.size());
1155 
1156       if (SpvOpExtInstImport == opcode) {
1157         // Record the extended instruction type for the ID for this import.
1158         // There is only one string literal argument to OpExtInstImport,
1159         // so it's sufficient to guard this just on the opcode.
1160         const spv_ext_inst_type_t ext_inst_type =
1161             spvExtInstImportTypeGet(str.data());
1162         if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
1163           return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1164               << "Invalid extended instruction import '" << str.data() << "'";
1165         }
1166         // We must have parsed a valid result ID.  It's a condition
1167         // of the grammar, and we only accept non-zero result Ids.
1168         assert(inst->result_id);
1169         const bool inserted = import_id_to_ext_inst_type_.emplace(
1170             inst->result_id, ext_inst_type).second;
1171         (void)inserted;
1172         assert(inserted);
1173       }
1174       break;
1175     }
1176 
1177     case SPV_OPERAND_TYPE_CAPABILITY:
1178     case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
1179     case SPV_OPERAND_TYPE_EXECUTION_MODEL:
1180     case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
1181     case SPV_OPERAND_TYPE_MEMORY_MODEL:
1182     case SPV_OPERAND_TYPE_EXECUTION_MODE:
1183     case SPV_OPERAND_TYPE_STORAGE_CLASS:
1184     case SPV_OPERAND_TYPE_DIMENSIONALITY:
1185     case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
1186     case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
1187     case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
1188     case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
1189     case SPV_OPERAND_TYPE_LINKAGE_TYPE:
1190     case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
1191     case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
1192     case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
1193     case SPV_OPERAND_TYPE_DECORATION:
1194     case SPV_OPERAND_TYPE_BUILT_IN:
1195     case SPV_OPERAND_TYPE_GROUP_OPERATION:
1196     case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
1197     case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
1198       // A single word that is a plain enum value.
1199       uint32_t word = 0;
1200       if (!DecodeOperandWord(type, &word))
1201         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1202             << "Failed to read enum";
1203 
1204       spirv_.push_back(word);
1205 
1206       // Map an optional operand type to its corresponding concrete type.
1207       if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
1208         parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
1209 
1210       spv_operand_desc entry;
1211       if (grammar_.lookupOperand(type, word, &entry)) {
1212         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1213             << "Invalid "
1214             << spvOperandTypeStr(parsed_operand.type)
1215             << " operand: " << word;
1216       }
1217 
1218       // Prepare to accept operands to this operand, if needed.
1219       spvPushOperandTypes(entry->operandTypes, expected_operands);
1220       break;
1221     }
1222 
1223     case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
1224     case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
1225     case SPV_OPERAND_TYPE_LOOP_CONTROL:
1226     case SPV_OPERAND_TYPE_IMAGE:
1227     case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
1228     case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
1229     case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
1230       // This operand is a mask.
1231       uint32_t word = 0;
1232       if (!DecodeOperandWord(type, &word))
1233         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1234             << "Failed to read " << spvOperandTypeStr(type)
1235             << " for " << spvOpcodeString(SpvOp(inst->opcode));
1236 
1237       spirv_.push_back(word);
1238 
1239       // Map an optional operand type to its corresponding concrete type.
1240       if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
1241         parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
1242       else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
1243         parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
1244 
1245       // Check validity of set mask bits. Also prepare for operands for those
1246       // masks if they have any.  To get operand order correct, scan from
1247       // MSB to LSB since we can only prepend operands to a pattern.
1248       // The only case in the grammar where you have more than one mask bit
1249       // having an operand is for image operands.  See SPIR-V 3.14 Image
1250       // Operands.
1251       uint32_t remaining_word = word;
1252       for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
1253         if (remaining_word & mask) {
1254           spv_operand_desc entry;
1255           if (grammar_.lookupOperand(type, mask, &entry)) {
1256             return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1257                    << "Invalid " << spvOperandTypeStr(parsed_operand.type)
1258                    << " operand: " << word << " has invalid mask component "
1259                    << mask;
1260           }
1261           remaining_word ^= mask;
1262           spvPushOperandTypes(entry->operandTypes, expected_operands);
1263         }
1264       }
1265       if (word == 0) {
1266         // An all-zeroes mask *might* also be valid.
1267         spv_operand_desc entry;
1268         if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
1269           // Prepare for its operands, if any.
1270           spvPushOperandTypes(entry->operandTypes, expected_operands);
1271         }
1272       }
1273       break;
1274     }
1275     default:
1276       return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1277           << "Internal error: Unhandled operand type: " << type;
1278   }
1279 
1280   parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index);
1281 
1282   assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type));
1283   assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type));
1284 
1285   parsed_operands_.push_back(parsed_operand);
1286 
1287   return SPV_SUCCESS;
1288 }
1289 
DecodeInstruction(spv_parsed_instruction_t * inst)1290 spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) {
1291   parsed_operands_.clear();
1292   const size_t instruction_offset = spirv_.size();
1293 
1294   bool read_result_id = false;
1295 
1296   while (true) {
1297     uint32_t word = 0;
1298     if (!reader_.ReadVariableWidthU32(&word,
1299                                       model_->opcode_chunk_length())) {
1300       return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1301           << "Failed to read opcode of instruction";
1302     }
1303 
1304     if (word >= kMarkvFirstOpcode) {
1305       if (word == kMarkvOpNextInstructionEncodesResultId) {
1306         read_result_id = true;
1307       } else {
1308         return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1309             << "Encountered unknown MARK-V opcode";
1310       }
1311     } else {
1312       inst->opcode = static_cast<uint16_t>(word);
1313       break;
1314     }
1315   }
1316 
1317   const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
1318 
1319   // Opcode/num_words placeholder, the word will be filled in later.
1320   spirv_.push_back(0);
1321 
1322   spv_opcode_desc opcode_desc;
1323   if (grammar_.lookupOpcode(opcode, &opcode_desc)
1324       != SPV_SUCCESS) {
1325     return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
1326   }
1327 
1328   spv_operand_pattern_t expected_operands;
1329   expected_operands.reserve(opcode_desc->numTypes);
1330   for (auto i = 0; i < opcode_desc->numTypes; i++)
1331     expected_operands.push_back(opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
1332 
1333   if (!OpcodeHasFixedNumberOfOperands(opcode)) {
1334     if (!reader_.ReadVariableWidthU16(&inst->num_operands,
1335                                       model_->num_operands_chunk_length()))
1336       return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1337           << "Failed to read num_operands of instruction";
1338   } else {
1339     inst->num_operands = static_cast<uint16_t>(expected_operands.size());
1340   }
1341 
1342   for (size_t operand_index = 0;
1343        operand_index < static_cast<size_t>(inst->num_operands);
1344        ++operand_index) {
1345     assert(!expected_operands.empty());
1346     const spv_operand_type_t type =
1347         spvTakeFirstMatchableOperand(&expected_operands);
1348 
1349     const size_t operand_offset = spirv_.size() - instruction_offset;
1350 
1351     const spv_result_t decode_result =
1352         DecodeOperand(instruction_offset, operand_offset, inst, type,
1353                       &expected_operands, read_result_id);
1354 
1355     if (decode_result != SPV_SUCCESS)
1356       return decode_result;
1357   }
1358 
1359   assert(inst->num_operands == parsed_operands_.size());
1360 
1361   // Only valid while spirv_ and parsed_operands_ remain unchanged.
1362   inst->words = &spirv_[instruction_offset];
1363   inst->operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
1364   inst->num_words = static_cast<uint16_t>(spirv_.size() - instruction_offset);
1365   spirv_[instruction_offset] =
1366       spvOpcodeMake(inst->num_words, SpvOp(inst->opcode));
1367 
1368   assert(inst->num_words == std::accumulate(
1369       parsed_operands_.begin(), parsed_operands_.end(), 1,
1370       [](int num_words, const spv_parsed_operand_t& operand) {
1371         return num_words += operand.num_words;
1372   }) && "num_words in instruction doesn't correspond to the sum of num_words"
1373         "in the operands");
1374 
1375   RecordNumberType(*inst);
1376 
1377   if (!ReadToByteBreakIfAgreed())
1378     return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1379         << "Failed to read to byte break";
1380 
1381   return SPV_SUCCESS;
1382 }
1383 
SetNumericTypeInfoForType(spv_parsed_operand_t * parsed_operand,uint32_t type_id)1384 spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
1385     spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
1386   assert(type_id != 0);
1387   auto type_info_iter = type_id_to_number_type_info_.find(type_id);
1388   if (type_info_iter == type_id_to_number_type_info_.end()) {
1389     return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1390         << "Type Id " << type_id << " is not a type";
1391   }
1392 
1393   const NumberType& info = type_info_iter->second;
1394   if (info.type == SPV_NUMBER_NONE) {
1395     // This is a valid type, but for something other than a scalar number.
1396     return vstate_.diag(SPV_ERROR_INVALID_BINARY)
1397         << "Type Id " << type_id << " is not a scalar numeric type";
1398   }
1399 
1400   parsed_operand->number_kind = info.type;
1401   parsed_operand->number_bit_width = info.bit_width;
1402   // Round up the word count.
1403   parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
1404   return SPV_SUCCESS;
1405 }
1406 
RecordNumberType(const spv_parsed_instruction_t & inst)1407 void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) {
1408   const SpvOp opcode = static_cast<SpvOp>(inst.opcode);
1409   if (spvOpcodeGeneratesType(opcode)) {
1410     NumberType info = {SPV_NUMBER_NONE, 0};
1411     if (SpvOpTypeInt == opcode) {
1412       info.bit_width = inst.words[inst.operands[1].offset];
1413       info.type = inst.words[inst.operands[2].offset] ?
1414           SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
1415     } else if (SpvOpTypeFloat == opcode) {
1416       info.bit_width = inst.words[inst.operands[1].offset];
1417       info.type = SPV_NUMBER_FLOATING;
1418     }
1419     // The *result* Id of a type generating instruction is the type Id.
1420     type_id_to_number_type_info_[inst.result_id] = info;
1421   }
1422 }
1423 
EncodeHeader(void * user_data,spv_endianness_t endian,uint32_t magic,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t schema)1424 spv_result_t EncodeHeader(
1425     void* user_data, spv_endianness_t endian, uint32_t magic,
1426     uint32_t version, uint32_t generator, uint32_t id_bound,
1427     uint32_t schema) {
1428   MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
1429   return encoder->EncodeHeader(
1430       endian, magic, version, generator, id_bound, schema);
1431 }
1432 
EncodeInstruction(void * user_data,const spv_parsed_instruction_t * inst)1433 spv_result_t EncodeInstruction(
1434     void* user_data, const spv_parsed_instruction_t* inst) {
1435   MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
1436   return encoder->EncodeInstruction(*inst);
1437 }
1438 
1439 }  // namespace
1440 
spvSpirvToMarkv(spv_const_context context,const uint32_t * spirv_words,const size_t spirv_num_words,spv_const_markv_encoder_options options,spv_markv_binary * markv_binary,spv_text * comments,spv_diagnostic * diagnostic)1441 spv_result_t spvSpirvToMarkv(spv_const_context context,
1442                              const uint32_t* spirv_words,
1443                              const size_t spirv_num_words,
1444                              spv_const_markv_encoder_options options,
1445                              spv_markv_binary* markv_binary,
1446                              spv_text* comments, spv_diagnostic* diagnostic) {
1447   spv_context_t hijack_context = *context;
1448   if (diagnostic) {
1449     *diagnostic = nullptr;
1450     libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
1451   }
1452 
1453   spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words};
1454 
1455   spv_endianness_t endian;
1456   spv_position_t position = {};
1457   if (spvBinaryEndianness(&spirv_binary, &endian)) {
1458     return libspirv::DiagnosticStream(position, hijack_context.consumer,
1459                                       SPV_ERROR_INVALID_BINARY)
1460         << "Invalid SPIR-V magic number.";
1461   }
1462 
1463   spv_header_t header;
1464   if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
1465     return libspirv::DiagnosticStream(position, hijack_context.consumer,
1466                                       SPV_ERROR_INVALID_BINARY)
1467         << "Invalid SPIR-V header.";
1468   }
1469 
1470   MarkvEncoder encoder(&hijack_context, options);
1471 
1472   if (comments) {
1473     encoder.CreateCommentsLogger();
1474 
1475     spv_text text = nullptr;
1476     if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words,
1477                         SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr)
1478         != SPV_SUCCESS) {
1479       return libspirv::DiagnosticStream(position, hijack_context.consumer,
1480                                         SPV_ERROR_INVALID_BINARY)
1481           << "Failed to disassemble SPIR-V binary.";
1482     }
1483     assert(text);
1484     encoder.SetDisassembly(std::string(text->str, text->length));
1485     spvTextDestroy(text);
1486   }
1487 
1488   if (spvBinaryParse(
1489       &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader,
1490       EncodeInstruction, diagnostic) != SPV_SUCCESS) {
1491     return libspirv::DiagnosticStream(position, hijack_context.consumer,
1492                                       SPV_ERROR_INVALID_BINARY)
1493         << "Unable to encode to MARK-V.";
1494   }
1495 
1496   if (comments)
1497     *comments = CreateSpvText(encoder.GetComments());
1498 
1499   *markv_binary = encoder.GetMarkvBinary();
1500   return SPV_SUCCESS;
1501 }
1502 
spvMarkvToSpirv(spv_const_context context,const uint8_t * markv_data,size_t markv_size_bytes,spv_const_markv_decoder_options options,spv_binary * spirv_binary,spv_text *,spv_diagnostic * diagnostic)1503 spv_result_t spvMarkvToSpirv(spv_const_context context,
1504                              const uint8_t* markv_data,
1505                              size_t markv_size_bytes,
1506                              spv_const_markv_decoder_options options,
1507                              spv_binary* spirv_binary,
1508                              spv_text* /* comments */, spv_diagnostic* diagnostic) {
1509   spv_position_t position = {};
1510   spv_context_t hijack_context = *context;
1511   if (diagnostic) {
1512     *diagnostic = nullptr;
1513     libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
1514   }
1515 
1516   MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options);
1517 
1518   std::vector<uint32_t> words;
1519 
1520   if (decoder.DecodeModule(&words) != SPV_SUCCESS) {
1521     return libspirv::DiagnosticStream(position, hijack_context.consumer,
1522                                       SPV_ERROR_INVALID_BINARY)
1523         << "Unable to decode MARK-V.";
1524   }
1525 
1526   assert(!words.empty());
1527 
1528   *spirv_binary = new spv_binary_t();
1529   (*spirv_binary)->code = new uint32_t[words.size()];
1530   (*spirv_binary)->wordCount = words.size();
1531   std::memcpy((*spirv_binary)->code, words.data(), 4 * words.size());
1532 
1533   return SPV_SUCCESS;
1534 }
1535 
spvMarkvBinaryDestroy(spv_markv_binary binary)1536 void spvMarkvBinaryDestroy(spv_markv_binary binary) {
1537   if (!binary) return;
1538   delete[] binary->data;
1539   delete binary;
1540 }
1541 
spvMarkvEncoderOptionsCreate()1542 spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {
1543   return new spv_markv_encoder_options_t;
1544 }
1545 
spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options)1546 void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {
1547   delete options;
1548 }
1549 
spvMarkvDecoderOptionsCreate()1550 spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {
1551   return new spv_markv_decoder_options_t;
1552 }
1553 
spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options)1554 void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {
1555   delete options;
1556 }
1557