• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "spirv-tools/optimizer.hpp"
16 
#include <cassert>
#include <charconv>
#include <cstdlib>
#include <memory>
#include <string>
#include <system_error>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24 
25 #include "source/opt/build_module.h"
26 #include "source/opt/graphics_robust_access_pass.h"
27 #include "source/opt/log.h"
28 #include "source/opt/pass_manager.h"
29 #include "source/opt/passes.h"
30 #include "source/spirv_optimizer_options.h"
31 #include "source/util/make_unique.h"
32 #include "source/util/string_utils.h"
33 
34 namespace spvtools {
35 
// Copies |string_count| NUL-terminated C strings from |strings| into a vector
// of std::string, preserving order.  |strings| is never dereferenced when
// |string_count| is zero, so it may be null in that case.
std::vector<std::string> GetVectorOfStrings(const char** strings,
                                            const size_t string_count) {
  std::vector<std::string> result;
  result.reserve(string_count);
  // Index with size_t so the counter can never wrap before reaching the bound.
  for (size_t i = 0; i < string_count; ++i) {
    result.emplace_back(strings[i]);
  }
  return result;
}
44 
45 struct Optimizer::PassToken::Impl {
Implspvtools::Optimizer::PassToken::Impl46   Impl(std::unique_ptr<opt::Pass> p) : pass(std::move(p)) {}
47 
48   std::unique_ptr<opt::Pass> pass;  // Internal implementation pass.
49 };
50 
PassToken(std::unique_ptr<Optimizer::PassToken::Impl> impl)51 Optimizer::PassToken::PassToken(
52     std::unique_ptr<Optimizer::PassToken::Impl> impl)
53     : impl_(std::move(impl)) {}
54 
PassToken(std::unique_ptr<opt::Pass> && pass)55 Optimizer::PassToken::PassToken(std::unique_ptr<opt::Pass>&& pass)
56     : impl_(MakeUnique<Optimizer::PassToken::Impl>(std::move(pass))) {}
57 
PassToken(PassToken && that)58 Optimizer::PassToken::PassToken(PassToken&& that)
59     : impl_(std::move(that.impl_)) {}
60 
operator =(PassToken && that)61 Optimizer::PassToken& Optimizer::PassToken::operator=(PassToken&& that) {
62   impl_ = std::move(that.impl_);
63   return *this;
64 }
65 
~PassToken()66 Optimizer::PassToken::~PassToken() {}
67 
68 struct Optimizer::Impl {
Implspvtools::Optimizer::Impl69   explicit Impl(spv_target_env env) : target_env(env), pass_manager() {}
70 
71   spv_target_env target_env;      // Target environment.
72   opt::PassManager pass_manager;  // Internal implementation pass manager.
73   std::unordered_set<uint32_t> live_locs;  // Arg to debug dead output passes
74 };
75 
Optimizer(spv_target_env env)76 Optimizer::Optimizer(spv_target_env env) : impl_(new Impl(env)) {
77   assert(env != SPV_ENV_WEBGPU_0);
78 }
79 
~Optimizer()80 Optimizer::~Optimizer() {}
81 
SetMessageConsumer(MessageConsumer c)82 void Optimizer::SetMessageConsumer(MessageConsumer c) {
83   // All passes' message consumer needs to be updated.
84   for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); ++i) {
85     impl_->pass_manager.GetPass(i)->SetMessageConsumer(c);
86   }
87   impl_->pass_manager.SetMessageConsumer(std::move(c));
88 }
89 
consumer() const90 const MessageConsumer& Optimizer::consumer() const {
91   return impl_->pass_manager.consumer();
92 }
93 
RegisterPass(PassToken && p)94 Optimizer& Optimizer::RegisterPass(PassToken&& p) {
95   // Change to use the pass manager's consumer.
96   p.impl_->pass->SetMessageConsumer(consumer());
97   impl_->pass_manager.AddPass(std::move(p.impl_->pass));
98   return *this;
99 }
100 
// The legalization passes take a SPIR-V shader generated by an HLSL front-end
// and turn it into a valid Vulkan SPIR-V shader.  There are two ways in which
// the code will be invalid at the start:
//
// 1) There will be opaque objects, like images, which will be passed around
//    in intermediate objects.  Valid SPIR-V will have to replace the use of
//    the opaque object with an intermediate object that is the result of the
//    load of the global opaque object.
//
// 2) There will be variables that contain pointers to structured or uniform
//    buffers.  To be legal, the variables must be eliminated, and the
//    references to the structured buffers must use the result of OpVariable
//    in the Uniform storage class.
//
// Optimizations in this list must accept shaders with these relaxations of
// the rules.  There is no guarantee that this list of optimizations is able
// to legalize all inputs; it works on a best-effort basis.
//
// The legalization problem is essentially a very general copy propagation
// problem.  The optimizations we use all either do copy propagation or
// enable more copy propagation.
RegisterLegalizationPasses(bool preserve_interface)122 Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
123   return
124       // Wrap OpKill instructions so all other code can be inlined.
125       RegisterPass(CreateWrapOpKillPass())
126           // Remove unreachable block so that merge return works.
127           .RegisterPass(CreateDeadBranchElimPass())
128           // Merge the returns so we can inline.
129           .RegisterPass(CreateMergeReturnPass())
130           // Make sure uses and definitions are in the same function.
131           .RegisterPass(CreateInlineExhaustivePass())
132           // Make private variable function scope
133           .RegisterPass(CreateEliminateDeadFunctionsPass())
134           .RegisterPass(CreatePrivateToLocalPass())
135           // Fix up the storage classes that DXC may have purposely generated
136           // incorrectly.  All functions are inlined, and a lot of dead code has
137           // been removed.
138           .RegisterPass(CreateFixStorageClassPass())
139           // Propagate the value stored to the loads in very simple cases.
140           .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
141           .RegisterPass(CreateLocalSingleStoreElimPass())
142           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
143           // Split up aggregates so they are easier to deal with.
144           .RegisterPass(CreateScalarReplacementPass(0))
145           // Remove loads and stores so everything is in intermediate values.
146           // Takes care of copy propagation of non-members.
147           .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
148           .RegisterPass(CreateLocalSingleStoreElimPass())
149           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
150           .RegisterPass(CreateLocalMultiStoreElimPass())
151           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
152           // Propagate constants to get as many constant conditions on branches
153           // as possible.
154           .RegisterPass(CreateCCPPass())
155           .RegisterPass(CreateLoopUnrollPass(true))
156           .RegisterPass(CreateDeadBranchElimPass())
157           // Copy propagate members.  Cleans up code sequences generated by
158           // scalar replacement.  Also important for removing OpPhi nodes.
159           .RegisterPass(CreateSimplificationPass())
160           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
161           .RegisterPass(CreateCopyPropagateArraysPass())
162           // May need loop unrolling here see
163           // https://github.com/Microsoft/DirectXShaderCompiler/pull/930
164           // Get rid of unused code that contain traces of illegal code
165           // or unused references to unbound external objects
166           .RegisterPass(CreateVectorDCEPass())
167           .RegisterPass(CreateDeadInsertElimPass())
168           .RegisterPass(CreateReduceLoadSizePass())
169           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
170           .RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
171           .RegisterPass(CreateInterpolateFixupPass())
172           .RegisterPass(CreateInvocationInterlockPlacementPass())
173           .RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass());
174 }
175 
RegisterLegalizationPasses()176 Optimizer& Optimizer::RegisterLegalizationPasses() {
177   return RegisterLegalizationPasses(false);
178 }
179 
RegisterPerformancePasses(bool preserve_interface)180 Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
181   return RegisterPass(CreateWrapOpKillPass())
182       .RegisterPass(CreateDeadBranchElimPass())
183       .RegisterPass(CreateMergeReturnPass())
184       .RegisterPass(CreateInlineExhaustivePass())
185       .RegisterPass(CreateEliminateDeadFunctionsPass())
186       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
187       .RegisterPass(CreatePrivateToLocalPass())
188       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
189       .RegisterPass(CreateLocalSingleStoreElimPass())
190       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
191       .RegisterPass(CreateScalarReplacementPass())
192       .RegisterPass(CreateLocalAccessChainConvertPass())
193       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
194       .RegisterPass(CreateLocalSingleStoreElimPass())
195       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
196       .RegisterPass(CreateLocalMultiStoreElimPass())
197       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
198       .RegisterPass(CreateCCPPass())
199       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
200       .RegisterPass(CreateLoopUnrollPass(true))
201       .RegisterPass(CreateDeadBranchElimPass())
202       .RegisterPass(CreateRedundancyEliminationPass())
203       .RegisterPass(CreateCombineAccessChainsPass())
204       .RegisterPass(CreateSimplificationPass())
205       .RegisterPass(CreateScalarReplacementPass())
206       .RegisterPass(CreateLocalAccessChainConvertPass())
207       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
208       .RegisterPass(CreateLocalSingleStoreElimPass())
209       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
210       .RegisterPass(CreateSSARewritePass())
211       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
212       .RegisterPass(CreateVectorDCEPass())
213       .RegisterPass(CreateDeadInsertElimPass())
214       .RegisterPass(CreateDeadBranchElimPass())
215       .RegisterPass(CreateSimplificationPass())
216       .RegisterPass(CreateIfConversionPass())
217       .RegisterPass(CreateCopyPropagateArraysPass())
218       .RegisterPass(CreateReduceLoadSizePass())
219       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
220       .RegisterPass(CreateBlockMergePass())
221       .RegisterPass(CreateRedundancyEliminationPass())
222       .RegisterPass(CreateDeadBranchElimPass())
223       .RegisterPass(CreateBlockMergePass())
224       .RegisterPass(CreateSimplificationPass());
225 }
226 
RegisterPerformancePasses()227 Optimizer& Optimizer::RegisterPerformancePasses() {
228   return RegisterPerformancePasses(false);
229 }
230 
RegisterSizePasses(bool preserve_interface)231 Optimizer& Optimizer::RegisterSizePasses(bool preserve_interface) {
232   return RegisterPass(CreateWrapOpKillPass())
233       .RegisterPass(CreateDeadBranchElimPass())
234       .RegisterPass(CreateMergeReturnPass())
235       .RegisterPass(CreateInlineExhaustivePass())
236       .RegisterPass(CreateEliminateDeadFunctionsPass())
237       .RegisterPass(CreatePrivateToLocalPass())
238       .RegisterPass(CreateScalarReplacementPass(0))
239       .RegisterPass(CreateLocalMultiStoreElimPass())
240       .RegisterPass(CreateCCPPass())
241       .RegisterPass(CreateLoopUnrollPass(true))
242       .RegisterPass(CreateDeadBranchElimPass())
243       .RegisterPass(CreateSimplificationPass())
244       .RegisterPass(CreateScalarReplacementPass(0))
245       .RegisterPass(CreateLocalSingleStoreElimPass())
246       .RegisterPass(CreateIfConversionPass())
247       .RegisterPass(CreateSimplificationPass())
248       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
249       .RegisterPass(CreateDeadBranchElimPass())
250       .RegisterPass(CreateBlockMergePass())
251       .RegisterPass(CreateLocalAccessChainConvertPass())
252       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
253       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
254       .RegisterPass(CreateCopyPropagateArraysPass())
255       .RegisterPass(CreateVectorDCEPass())
256       .RegisterPass(CreateDeadInsertElimPass())
257       .RegisterPass(CreateEliminateDeadMembersPass())
258       .RegisterPass(CreateLocalSingleStoreElimPass())
259       .RegisterPass(CreateBlockMergePass())
260       .RegisterPass(CreateLocalMultiStoreElimPass())
261       .RegisterPass(CreateRedundancyEliminationPass())
262       .RegisterPass(CreateSimplificationPass())
263       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
264       .RegisterPass(CreateCFGCleanupPass());
265 }
266 
RegisterSizePasses()267 Optimizer& Optimizer::RegisterSizePasses() { return RegisterSizePasses(false); }
268 
RegisterPassesFromFlags(const std::vector<std::string> & flags)269 bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags) {
270   return RegisterPassesFromFlags(flags, false);
271 }
272 
RegisterPassesFromFlags(const std::vector<std::string> & flags,bool preserve_interface)273 bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags,
274                                         bool preserve_interface) {
275   for (const auto& flag : flags) {
276     if (!RegisterPassFromFlag(flag, preserve_interface)) {
277       return false;
278     }
279   }
280 
281   return true;
282 }
283 
FlagHasValidForm(const std::string & flag) const284 bool Optimizer::FlagHasValidForm(const std::string& flag) const {
285   if (flag == "-O" || flag == "-Os") {
286     return true;
287   } else if (flag.size() > 2 && flag.substr(0, 2) == "--") {
288     return true;
289   }
290 
291   Errorf(consumer(), nullptr, {},
292          "%s is not a valid flag.  Flag passes should have the form "
293          "'--pass_name[=pass_args]'. Special flag names also accepted: -O "
294          "and -Os.",
295          flag.c_str());
296   return false;
297 }
298 
RegisterPassFromFlag(const std::string & flag)299 bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
300   return RegisterPassFromFlag(flag, false);
301 }
302 
RegisterPassFromFlag(const std::string & flag,bool preserve_interface)303 bool Optimizer::RegisterPassFromFlag(const std::string& flag,
304                                      bool preserve_interface) {
305   if (!FlagHasValidForm(flag)) {
306     return false;
307   }
308 
309   // Split flags of the form --pass_name=pass_args.
310   auto p = utils::SplitFlagArgs(flag);
311   std::string pass_name = p.first;
312   std::string pass_args = p.second;
313 
314   // FIXME(dnovillo): This should be re-factored so that pass names can be
315   // automatically checked against Pass::name() and PassToken instances created
316   // via a template function.  Additionally, class Pass should have a desc()
317   // method that describes the pass (so it can be used in --help).
318   //
319   // Both Pass::name() and Pass::desc() should be static class members so they
320   // can be invoked without creating a pass instance.
321   if (pass_name == "strip-debug") {
322     RegisterPass(CreateStripDebugInfoPass());
323   } else if (pass_name == "strip-reflect") {
324     RegisterPass(CreateStripReflectInfoPass());
325   } else if (pass_name == "strip-nonsemantic") {
326     RegisterPass(CreateStripNonSemanticInfoPass());
327   } else if (pass_name == "fix-opextinst-opcodes") {
328     RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass());
329   } else if (pass_name == "set-spec-const-default-value") {
330     if (pass_args.size() > 0) {
331       auto spec_ids_vals =
332           opt::SetSpecConstantDefaultValuePass::ParseDefaultValuesString(
333               pass_args.c_str());
334       if (!spec_ids_vals) {
335         Errorf(consumer(), nullptr, {},
336                "Invalid argument for --set-spec-const-default-value: %s",
337                pass_args.c_str());
338         return false;
339       }
340       RegisterPass(
341           CreateSetSpecConstantDefaultValuePass(std::move(*spec_ids_vals)));
342     } else {
343       Errorf(consumer(), nullptr, {},
344              "Invalid spec constant value string '%s'. Expected a string of "
345              "<spec id>:<default value> pairs.",
346              pass_args.c_str());
347       return false;
348     }
349   } else if (pass_name == "if-conversion") {
350     RegisterPass(CreateIfConversionPass());
351   } else if (pass_name == "freeze-spec-const") {
352     RegisterPass(CreateFreezeSpecConstantValuePass());
353   } else if (pass_name == "inline-entry-points-exhaustive") {
354     RegisterPass(CreateInlineExhaustivePass());
355   } else if (pass_name == "inline-entry-points-opaque") {
356     RegisterPass(CreateInlineOpaquePass());
357   } else if (pass_name == "combine-access-chains") {
358     RegisterPass(CreateCombineAccessChainsPass());
359   } else if (pass_name == "convert-local-access-chains") {
360     RegisterPass(CreateLocalAccessChainConvertPass());
361   } else if (pass_name == "replace-desc-array-access-using-var-index") {
362     RegisterPass(CreateReplaceDescArrayAccessUsingVarIndexPass());
363   } else if (pass_name == "spread-volatile-semantics") {
364     RegisterPass(CreateSpreadVolatileSemanticsPass());
365   } else if (pass_name == "descriptor-scalar-replacement") {
366     RegisterPass(CreateDescriptorScalarReplacementPass());
367   } else if (pass_name == "eliminate-dead-code-aggressive") {
368     RegisterPass(CreateAggressiveDCEPass(preserve_interface));
369   } else if (pass_name == "eliminate-insert-extract") {
370     RegisterPass(CreateInsertExtractElimPass());
371   } else if (pass_name == "eliminate-local-single-block") {
372     RegisterPass(CreateLocalSingleBlockLoadStoreElimPass());
373   } else if (pass_name == "eliminate-local-single-store") {
374     RegisterPass(CreateLocalSingleStoreElimPass());
375   } else if (pass_name == "merge-blocks") {
376     RegisterPass(CreateBlockMergePass());
377   } else if (pass_name == "merge-return") {
378     RegisterPass(CreateMergeReturnPass());
379   } else if (pass_name == "eliminate-dead-branches") {
380     RegisterPass(CreateDeadBranchElimPass());
381   } else if (pass_name == "eliminate-dead-functions") {
382     RegisterPass(CreateEliminateDeadFunctionsPass());
383   } else if (pass_name == "eliminate-local-multi-store") {
384     RegisterPass(CreateLocalMultiStoreElimPass());
385   } else if (pass_name == "eliminate-dead-const") {
386     RegisterPass(CreateEliminateDeadConstantPass());
387   } else if (pass_name == "eliminate-dead-inserts") {
388     RegisterPass(CreateDeadInsertElimPass());
389   } else if (pass_name == "eliminate-dead-variables") {
390     RegisterPass(CreateDeadVariableEliminationPass());
391   } else if (pass_name == "eliminate-dead-members") {
392     RegisterPass(CreateEliminateDeadMembersPass());
393   } else if (pass_name == "fold-spec-const-op-composite") {
394     RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
395   } else if (pass_name == "loop-unswitch") {
396     RegisterPass(CreateLoopUnswitchPass());
397   } else if (pass_name == "scalar-replacement") {
398     if (pass_args.size() == 0) {
399       RegisterPass(CreateScalarReplacementPass());
400     } else {
401       int limit = -1;
402       if (pass_args.find_first_not_of("0123456789") == std::string::npos) {
403         limit = atoi(pass_args.c_str());
404       }
405 
406       if (limit >= 0) {
407         RegisterPass(CreateScalarReplacementPass(limit));
408       } else {
409         Error(consumer(), nullptr, {},
410               "--scalar-replacement must have no arguments or a non-negative "
411               "integer argument");
412         return false;
413       }
414     }
415   } else if (pass_name == "strength-reduction") {
416     RegisterPass(CreateStrengthReductionPass());
417   } else if (pass_name == "unify-const") {
418     RegisterPass(CreateUnifyConstantPass());
419   } else if (pass_name == "flatten-decorations") {
420     RegisterPass(CreateFlattenDecorationPass());
421   } else if (pass_name == "compact-ids") {
422     RegisterPass(CreateCompactIdsPass());
423   } else if (pass_name == "cfg-cleanup") {
424     RegisterPass(CreateCFGCleanupPass());
425   } else if (pass_name == "local-redundancy-elimination") {
426     RegisterPass(CreateLocalRedundancyEliminationPass());
427   } else if (pass_name == "loop-invariant-code-motion") {
428     RegisterPass(CreateLoopInvariantCodeMotionPass());
429   } else if (pass_name == "reduce-load-size") {
430     if (pass_args.size() == 0) {
431       RegisterPass(CreateReduceLoadSizePass());
432     } else {
433       double load_replacement_threshold = 0.9;
434       if (pass_args.find_first_not_of(".0123456789") == std::string::npos) {
435         load_replacement_threshold = atof(pass_args.c_str());
436       }
437 
438       if (load_replacement_threshold >= 0) {
439         RegisterPass(CreateReduceLoadSizePass(load_replacement_threshold));
440       } else {
441         Error(consumer(), nullptr, {},
442               "--reduce-load-size must have no arguments or a non-negative "
443               "double argument");
444         return false;
445       }
446     }
447   } else if (pass_name == "redundancy-elimination") {
448     RegisterPass(CreateRedundancyEliminationPass());
449   } else if (pass_name == "private-to-local") {
450     RegisterPass(CreatePrivateToLocalPass());
451   } else if (pass_name == "remove-duplicates") {
452     RegisterPass(CreateRemoveDuplicatesPass());
453   } else if (pass_name == "workaround-1209") {
454     RegisterPass(CreateWorkaround1209Pass());
455   } else if (pass_name == "replace-invalid-opcode") {
456     RegisterPass(CreateReplaceInvalidOpcodePass());
457   } else if (pass_name == "convert-relaxed-to-half") {
458     RegisterPass(CreateConvertRelaxedToHalfPass());
459   } else if (pass_name == "relax-float-ops") {
460     RegisterPass(CreateRelaxFloatOpsPass());
461   } else if (pass_name == "inst-debug-printf") {
462     // This private option is not for user consumption.
463     // It is here to assist in debugging and fixing the debug printf
464     // instrumentation pass.
465     // For users who wish to utilize debug printf, see the white paper at
466     // https://www.lunarg.com/wp-content/uploads/2021/08/Using-Debug-Printf-02August2021.pdf
467     RegisterPass(CreateInstDebugPrintfPass(7, 23));
468   } else if (pass_name == "simplify-instructions") {
469     RegisterPass(CreateSimplificationPass());
470   } else if (pass_name == "ssa-rewrite") {
471     RegisterPass(CreateSSARewritePass());
472   } else if (pass_name == "copy-propagate-arrays") {
473     RegisterPass(CreateCopyPropagateArraysPass());
474   } else if (pass_name == "loop-fission") {
475     int register_threshold_to_split =
476         (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1;
477     if (register_threshold_to_split > 0) {
478       RegisterPass(CreateLoopFissionPass(
479           static_cast<size_t>(register_threshold_to_split)));
480     } else {
481       Error(consumer(), nullptr, {},
482             "--loop-fission must have a positive integer argument");
483       return false;
484     }
485   } else if (pass_name == "loop-fusion") {
486     int max_registers_per_loop =
487         (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1;
488     if (max_registers_per_loop > 0) {
489       RegisterPass(
490           CreateLoopFusionPass(static_cast<size_t>(max_registers_per_loop)));
491     } else {
492       Error(consumer(), nullptr, {},
493             "--loop-fusion must have a positive integer argument");
494       return false;
495     }
496   } else if (pass_name == "loop-unroll") {
497     RegisterPass(CreateLoopUnrollPass(true));
498   } else if (pass_name == "upgrade-memory-model") {
499     RegisterPass(CreateUpgradeMemoryModelPass());
500   } else if (pass_name == "vector-dce") {
501     RegisterPass(CreateVectorDCEPass());
502   } else if (pass_name == "loop-unroll-partial") {
503     int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0;
504     if (factor > 0) {
505       RegisterPass(CreateLoopUnrollPass(false, factor));
506     } else {
507       Error(consumer(), nullptr, {},
508             "--loop-unroll-partial must have a positive integer argument");
509       return false;
510     }
511   } else if (pass_name == "loop-peeling") {
512     RegisterPass(CreateLoopPeelingPass());
513   } else if (pass_name == "loop-peeling-threshold") {
514     int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0;
515     if (factor > 0) {
516       opt::LoopPeelingPass::SetLoopPeelingThreshold(factor);
517     } else {
518       Error(consumer(), nullptr, {},
519             "--loop-peeling-threshold must have a positive integer argument");
520       return false;
521     }
522   } else if (pass_name == "ccp") {
523     RegisterPass(CreateCCPPass());
524   } else if (pass_name == "code-sink") {
525     RegisterPass(CreateCodeSinkingPass());
526   } else if (pass_name == "fix-storage-class") {
527     RegisterPass(CreateFixStorageClassPass());
528   } else if (pass_name == "O") {
529     RegisterPerformancePasses(preserve_interface);
530   } else if (pass_name == "Os") {
531     RegisterSizePasses(preserve_interface);
532   } else if (pass_name == "legalize-hlsl") {
533     RegisterLegalizationPasses(preserve_interface);
534   } else if (pass_name == "remove-unused-interface-variables") {
535     RegisterPass(CreateRemoveUnusedInterfaceVariablesPass());
536   } else if (pass_name == "graphics-robust-access") {
537     RegisterPass(CreateGraphicsRobustAccessPass());
538   } else if (pass_name == "wrap-opkill") {
539     RegisterPass(CreateWrapOpKillPass());
540   } else if (pass_name == "amd-ext-to-khr") {
541     RegisterPass(CreateAmdExtToKhrPass());
542   } else if (pass_name == "interpolate-fixup") {
543     RegisterPass(CreateInterpolateFixupPass());
544   } else if (pass_name == "remove-dont-inline") {
545     RegisterPass(CreateRemoveDontInlinePass());
546   } else if (pass_name == "eliminate-dead-input-components") {
547     RegisterPass(CreateEliminateDeadInputComponentsSafePass());
548   } else if (pass_name == "fix-func-call-param") {
549     RegisterPass(CreateFixFuncCallArgumentsPass());
550   } else if (pass_name == "convert-to-sampled-image") {
551     if (pass_args.size() > 0) {
552       auto descriptor_set_binding_pairs =
553           opt::ConvertToSampledImagePass::ParseDescriptorSetBindingPairsString(
554               pass_args.c_str());
555       if (!descriptor_set_binding_pairs) {
556         Errorf(consumer(), nullptr, {},
557                "Invalid argument for --convert-to-sampled-image: %s",
558                pass_args.c_str());
559         return false;
560       }
561       RegisterPass(CreateConvertToSampledImagePass(
562           std::move(*descriptor_set_binding_pairs)));
563     } else {
564       Errorf(consumer(), nullptr, {},
565              "Invalid pairs of descriptor set and binding '%s'. Expected a "
566              "string of <descriptor set>:<binding> pairs.",
567              pass_args.c_str());
568       return false;
569     }
570   } else if (pass_name == "switch-descriptorset") {
571     if (pass_args.size() == 0) {
572       Error(consumer(), nullptr, {},
573             "--switch-descriptorset requires a from:to argument.");
574       return false;
575     }
576     uint32_t from_set = 0, to_set = 0;
577     const char* start = pass_args.data();
578     const char* end = pass_args.data() + pass_args.size();
579 
580     auto result = std::from_chars(start, end, from_set);
581     if (result.ec != std::errc()) {
582       Errorf(consumer(), nullptr, {},
583              "Invalid argument for --switch-descriptorset: %s",
584              pass_args.c_str());
585       return false;
586     }
587     start = result.ptr;
588     if (start[0] != ':') {
589       Errorf(consumer(), nullptr, {},
590              "Invalid argument for --switch-descriptorset: %s",
591              pass_args.c_str());
592       return false;
593     }
594     start++;
595     result = std::from_chars(start, end, to_set);
596     if (result.ec != std::errc() || result.ptr != end) {
597       Errorf(consumer(), nullptr, {},
598              "Invalid argument for --switch-descriptorset: %s",
599              pass_args.c_str());
600       return false;
601     }
602     RegisterPass(CreateSwitchDescriptorSetPass(from_set, to_set));
603   } else if (pass_name == "modify-maximal-reconvergence") {
604     if (pass_args.size() == 0) {
605       Error(consumer(), nullptr, {},
606             "--modify-maximal-reconvergence requires an argument");
607       return false;
608     }
609     if (pass_args == "add") {
610       RegisterPass(CreateModifyMaximalReconvergencePass(true));
611     } else if (pass_args == "remove") {
612       RegisterPass(CreateModifyMaximalReconvergencePass(false));
613     } else {
614       Errorf(consumer(), nullptr, {},
615              "Invalid argument for --modify-maximal-reconvergence: %s (must be "
616              "'add' or 'remove')",
617              pass_args.c_str());
618       return false;
619     }
620   } else if (pass_name == "trim-capabilities") {
621     RegisterPass(CreateTrimCapabilitiesPass());
622   } else {
623     Errorf(consumer(), nullptr, {},
624            "Unknown flag '--%s'. Use --help for a list of valid flags",
625            pass_name.c_str());
626     return false;
627   }
628 
629   return true;
630 }
631 
SetTargetEnv(const spv_target_env env)632 void Optimizer::SetTargetEnv(const spv_target_env env) {
633   impl_->target_env = env;
634 }
635 
Run(const uint32_t * original_binary,const size_t original_binary_size,std::vector<uint32_t> * optimized_binary) const636 bool Optimizer::Run(const uint32_t* original_binary,
637                     const size_t original_binary_size,
638                     std::vector<uint32_t>* optimized_binary) const {
639   return Run(original_binary, original_binary_size, optimized_binary,
640              OptimizerOptions());
641 }
642 
Run(const uint32_t * original_binary,const size_t original_binary_size,std::vector<uint32_t> * optimized_binary,const ValidatorOptions & validator_options,bool skip_validation) const643 bool Optimizer::Run(const uint32_t* original_binary,
644                     const size_t original_binary_size,
645                     std::vector<uint32_t>* optimized_binary,
646                     const ValidatorOptions& validator_options,
647                     bool skip_validation) const {
648   OptimizerOptions opt_options;
649   opt_options.set_run_validator(!skip_validation);
650   opt_options.set_validator_options(validator_options);
651   return Run(original_binary, original_binary_size, optimized_binary,
652              opt_options);
653 }
654 
Run(const uint32_t * original_binary,const size_t original_binary_size,std::vector<uint32_t> * optimized_binary,const spv_optimizer_options opt_options) const655 bool Optimizer::Run(const uint32_t* original_binary,
656                     const size_t original_binary_size,
657                     std::vector<uint32_t>* optimized_binary,
658                     const spv_optimizer_options opt_options) const {
659   spvtools::SpirvTools tools(impl_->target_env);
660   tools.SetMessageConsumer(impl_->pass_manager.consumer());
661   if (opt_options->run_validator_ &&
662       !tools.Validate(original_binary, original_binary_size,
663                       &opt_options->val_options_)) {
664     return false;
665   }
666 
667   std::unique_ptr<opt::IRContext> context = BuildModule(
668       impl_->target_env, consumer(), original_binary, original_binary_size);
669   if (context == nullptr) return false;
670 
671   context->set_max_id_bound(opt_options->max_id_bound_);
672   context->set_preserve_bindings(opt_options->preserve_bindings_);
673   context->set_preserve_spec_constants(opt_options->preserve_spec_constants_);
674 
675   impl_->pass_manager.SetValidatorOptions(&opt_options->val_options_);
676   impl_->pass_manager.SetTargetEnv(impl_->target_env);
677   auto status = impl_->pass_manager.Run(context.get());
678 
679   if (status == opt::Pass::Status::Failure) {
680     return false;
681   }
682 
683 #ifndef NDEBUG
684   // We do not keep the result id of DebugScope in struct DebugScope.
685   // Instead, we assign random ids for them, which results in integrity
686   // check failures. In addition, propagating the OpLine/OpNoLine to preserve
687   // the debug information through transformations results in integrity
688   // check failures. We want to skip the integrity check when the module
689   // contains DebugScope or OpLine/OpNoLine instructions.
690   if (status == opt::Pass::Status::SuccessWithoutChange &&
691       !context->module()->ContainsDebugInfo()) {
692     std::vector<uint32_t> optimized_binary_with_nop;
693     context->module()->ToBinary(&optimized_binary_with_nop,
694                                 /* skip_nop = */ false);
695     assert(optimized_binary_with_nop.size() == original_binary_size &&
696            "Binary size unexpectedly changed despite the optimizer saying "
697            "there was no change");
698 
699     // Compare the magic number to make sure the binaries were encoded in the
700     // endianness.  If not, the contents of the binaries will be different, so
701     // do not check the contents.
702     if (optimized_binary_with_nop[0] == original_binary[0]) {
703       assert(memcmp(optimized_binary_with_nop.data(), original_binary,
704                     original_binary_size) == 0 &&
705              "Binary content unexpectedly changed despite the optimizer saying "
706              "there was no change");
707     }
708   }
709 #endif  // !NDEBUG
710 
711   // Note that |original_binary| and |optimized_binary| may share the same
712   // buffer and the below will invalidate |original_binary|.
713   optimized_binary->clear();
714   context->module()->ToBinary(optimized_binary, /* skip_nop = */ true);
715 
716   return true;
717 }
718 
SetPrintAll(std::ostream * out)719 Optimizer& Optimizer::SetPrintAll(std::ostream* out) {
720   impl_->pass_manager.SetPrintAll(out);
721   return *this;
722 }
723 
SetTimeReport(std::ostream * out)724 Optimizer& Optimizer::SetTimeReport(std::ostream* out) {
725   impl_->pass_manager.SetTimeReport(out);
726   return *this;
727 }
728 
SetValidateAfterAll(bool validate)729 Optimizer& Optimizer::SetValidateAfterAll(bool validate) {
730   impl_->pass_manager.SetValidateAfterAll(validate);
731   return *this;
732 }
733 
CreateNullPass()734 Optimizer::PassToken CreateNullPass() {
735   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::NullPass>());
736 }
737 
CreateStripDebugInfoPass()738 Optimizer::PassToken CreateStripDebugInfoPass() {
739   return MakeUnique<Optimizer::PassToken::Impl>(
740       MakeUnique<opt::StripDebugInfoPass>());
741 }
742 
CreateStripReflectInfoPass()743 Optimizer::PassToken CreateStripReflectInfoPass() {
744   return CreateStripNonSemanticInfoPass();
745 }
746 
CreateStripNonSemanticInfoPass()747 Optimizer::PassToken CreateStripNonSemanticInfoPass() {
748   return MakeUnique<Optimizer::PassToken::Impl>(
749       MakeUnique<opt::StripNonSemanticInfoPass>());
750 }
751 
CreateEliminateDeadFunctionsPass()752 Optimizer::PassToken CreateEliminateDeadFunctionsPass() {
753   return MakeUnique<Optimizer::PassToken::Impl>(
754       MakeUnique<opt::EliminateDeadFunctionsPass>());
755 }
756 
CreateEliminateDeadMembersPass()757 Optimizer::PassToken CreateEliminateDeadMembersPass() {
758   return MakeUnique<Optimizer::PassToken::Impl>(
759       MakeUnique<opt::EliminateDeadMembersPass>());
760 }
761 
CreateSetSpecConstantDefaultValuePass(const std::unordered_map<uint32_t,std::string> & id_value_map)762 Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
763     const std::unordered_map<uint32_t, std::string>& id_value_map) {
764   return MakeUnique<Optimizer::PassToken::Impl>(
765       MakeUnique<opt::SetSpecConstantDefaultValuePass>(id_value_map));
766 }
767 
CreateSetSpecConstantDefaultValuePass(const std::unordered_map<uint32_t,std::vector<uint32_t>> & id_value_map)768 Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
769     const std::unordered_map<uint32_t, std::vector<uint32_t>>& id_value_map) {
770   return MakeUnique<Optimizer::PassToken::Impl>(
771       MakeUnique<opt::SetSpecConstantDefaultValuePass>(id_value_map));
772 }
773 
CreateFlattenDecorationPass()774 Optimizer::PassToken CreateFlattenDecorationPass() {
775   return MakeUnique<Optimizer::PassToken::Impl>(
776       MakeUnique<opt::FlattenDecorationPass>());
777 }
778 
CreateFreezeSpecConstantValuePass()779 Optimizer::PassToken CreateFreezeSpecConstantValuePass() {
780   return MakeUnique<Optimizer::PassToken::Impl>(
781       MakeUnique<opt::FreezeSpecConstantValuePass>());
782 }
783 
CreateFoldSpecConstantOpAndCompositePass()784 Optimizer::PassToken CreateFoldSpecConstantOpAndCompositePass() {
785   return MakeUnique<Optimizer::PassToken::Impl>(
786       MakeUnique<opt::FoldSpecConstantOpAndCompositePass>());
787 }
788 
CreateUnifyConstantPass()789 Optimizer::PassToken CreateUnifyConstantPass() {
790   return MakeUnique<Optimizer::PassToken::Impl>(
791       MakeUnique<opt::UnifyConstantPass>());
792 }
793 
CreateEliminateDeadConstantPass()794 Optimizer::PassToken CreateEliminateDeadConstantPass() {
795   return MakeUnique<Optimizer::PassToken::Impl>(
796       MakeUnique<opt::EliminateDeadConstantPass>());
797 }
798 
CreateDeadVariableEliminationPass()799 Optimizer::PassToken CreateDeadVariableEliminationPass() {
800   return MakeUnique<Optimizer::PassToken::Impl>(
801       MakeUnique<opt::DeadVariableElimination>());
802 }
803 
CreateStrengthReductionPass()804 Optimizer::PassToken CreateStrengthReductionPass() {
805   return MakeUnique<Optimizer::PassToken::Impl>(
806       MakeUnique<opt::StrengthReductionPass>());
807 }
808 
CreateBlockMergePass()809 Optimizer::PassToken CreateBlockMergePass() {
810   return MakeUnique<Optimizer::PassToken::Impl>(
811       MakeUnique<opt::BlockMergePass>());
812 }
813 
CreateInlineExhaustivePass()814 Optimizer::PassToken CreateInlineExhaustivePass() {
815   return MakeUnique<Optimizer::PassToken::Impl>(
816       MakeUnique<opt::InlineExhaustivePass>());
817 }
818 
CreateInlineOpaquePass()819 Optimizer::PassToken CreateInlineOpaquePass() {
820   return MakeUnique<Optimizer::PassToken::Impl>(
821       MakeUnique<opt::InlineOpaquePass>());
822 }
823 
CreateLocalAccessChainConvertPass()824 Optimizer::PassToken CreateLocalAccessChainConvertPass() {
825   return MakeUnique<Optimizer::PassToken::Impl>(
826       MakeUnique<opt::LocalAccessChainConvertPass>());
827 }
828 
CreateLocalSingleBlockLoadStoreElimPass()829 Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass() {
830   return MakeUnique<Optimizer::PassToken::Impl>(
831       MakeUnique<opt::LocalSingleBlockLoadStoreElimPass>());
832 }
833 
CreateLocalSingleStoreElimPass()834 Optimizer::PassToken CreateLocalSingleStoreElimPass() {
835   return MakeUnique<Optimizer::PassToken::Impl>(
836       MakeUnique<opt::LocalSingleStoreElimPass>());
837 }
838 
CreateInsertExtractElimPass()839 Optimizer::PassToken CreateInsertExtractElimPass() {
840   return MakeUnique<Optimizer::PassToken::Impl>(
841       MakeUnique<opt::SimplificationPass>());
842 }
843 
CreateDeadInsertElimPass()844 Optimizer::PassToken CreateDeadInsertElimPass() {
845   return MakeUnique<Optimizer::PassToken::Impl>(
846       MakeUnique<opt::DeadInsertElimPass>());
847 }
848 
CreateDeadBranchElimPass()849 Optimizer::PassToken CreateDeadBranchElimPass() {
850   return MakeUnique<Optimizer::PassToken::Impl>(
851       MakeUnique<opt::DeadBranchElimPass>());
852 }
853 
CreateLocalMultiStoreElimPass()854 Optimizer::PassToken CreateLocalMultiStoreElimPass() {
855   return MakeUnique<Optimizer::PassToken::Impl>(
856       MakeUnique<opt::SSARewritePass>());
857 }
858 
CreateAggressiveDCEPass()859 Optimizer::PassToken CreateAggressiveDCEPass() {
860   return MakeUnique<Optimizer::PassToken::Impl>(
861       MakeUnique<opt::AggressiveDCEPass>(false, false));
862 }
863 
CreateAggressiveDCEPass(bool preserve_interface)864 Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface) {
865   return MakeUnique<Optimizer::PassToken::Impl>(
866       MakeUnique<opt::AggressiveDCEPass>(preserve_interface, false));
867 }
868 
CreateAggressiveDCEPass(bool preserve_interface,bool remove_outputs)869 Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface,
870                                              bool remove_outputs) {
871   return MakeUnique<Optimizer::PassToken::Impl>(
872       MakeUnique<opt::AggressiveDCEPass>(preserve_interface, remove_outputs));
873 }
874 
CreateRemoveUnusedInterfaceVariablesPass()875 Optimizer::PassToken CreateRemoveUnusedInterfaceVariablesPass() {
876   return MakeUnique<Optimizer::PassToken::Impl>(
877       MakeUnique<opt::RemoveUnusedInterfaceVariablesPass>());
878 }
879 
CreatePropagateLineInfoPass()880 Optimizer::PassToken CreatePropagateLineInfoPass() {
881   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::EmptyPass>());
882 }
883 
CreateRedundantLineInfoElimPass()884 Optimizer::PassToken CreateRedundantLineInfoElimPass() {
885   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::EmptyPass>());
886 }
887 
CreateCompactIdsPass()888 Optimizer::PassToken CreateCompactIdsPass() {
889   return MakeUnique<Optimizer::PassToken::Impl>(
890       MakeUnique<opt::CompactIdsPass>());
891 }
892 
CreateMergeReturnPass()893 Optimizer::PassToken CreateMergeReturnPass() {
894   return MakeUnique<Optimizer::PassToken::Impl>(
895       MakeUnique<opt::MergeReturnPass>());
896 }
897 
GetPassNames() const898 std::vector<const char*> Optimizer::GetPassNames() const {
899   std::vector<const char*> v;
900   for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); i++) {
901     v.push_back(impl_->pass_manager.GetPass(i)->name());
902   }
903   return v;
904 }
905 
CreateCFGCleanupPass()906 Optimizer::PassToken CreateCFGCleanupPass() {
907   return MakeUnique<Optimizer::PassToken::Impl>(
908       MakeUnique<opt::CFGCleanupPass>());
909 }
910 
CreateLocalRedundancyEliminationPass()911 Optimizer::PassToken CreateLocalRedundancyEliminationPass() {
912   return MakeUnique<Optimizer::PassToken::Impl>(
913       MakeUnique<opt::LocalRedundancyEliminationPass>());
914 }
915 
CreateLoopFissionPass(size_t threshold)916 Optimizer::PassToken CreateLoopFissionPass(size_t threshold) {
917   return MakeUnique<Optimizer::PassToken::Impl>(
918       MakeUnique<opt::LoopFissionPass>(threshold));
919 }
920 
CreateLoopFusionPass(size_t max_registers_per_loop)921 Optimizer::PassToken CreateLoopFusionPass(size_t max_registers_per_loop) {
922   return MakeUnique<Optimizer::PassToken::Impl>(
923       MakeUnique<opt::LoopFusionPass>(max_registers_per_loop));
924 }
925 
CreateLoopInvariantCodeMotionPass()926 Optimizer::PassToken CreateLoopInvariantCodeMotionPass() {
927   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::LICMPass>());
928 }
929 
CreateLoopPeelingPass()930 Optimizer::PassToken CreateLoopPeelingPass() {
931   return MakeUnique<Optimizer::PassToken::Impl>(
932       MakeUnique<opt::LoopPeelingPass>());
933 }
934 
CreateLoopUnswitchPass()935 Optimizer::PassToken CreateLoopUnswitchPass() {
936   return MakeUnique<Optimizer::PassToken::Impl>(
937       MakeUnique<opt::LoopUnswitchPass>());
938 }
939 
CreateRedundancyEliminationPass()940 Optimizer::PassToken CreateRedundancyEliminationPass() {
941   return MakeUnique<Optimizer::PassToken::Impl>(
942       MakeUnique<opt::RedundancyEliminationPass>());
943 }
944 
CreateRemoveDuplicatesPass()945 Optimizer::PassToken CreateRemoveDuplicatesPass() {
946   return MakeUnique<Optimizer::PassToken::Impl>(
947       MakeUnique<opt::RemoveDuplicatesPass>());
948 }
949 
CreateScalarReplacementPass(uint32_t size_limit)950 Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit) {
951   return MakeUnique<Optimizer::PassToken::Impl>(
952       MakeUnique<opt::ScalarReplacementPass>(size_limit));
953 }
954 
CreatePrivateToLocalPass()955 Optimizer::PassToken CreatePrivateToLocalPass() {
956   return MakeUnique<Optimizer::PassToken::Impl>(
957       MakeUnique<opt::PrivateToLocalPass>());
958 }
959 
CreateCCPPass()960 Optimizer::PassToken CreateCCPPass() {
961   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::CCPPass>());
962 }
963 
CreateWorkaround1209Pass()964 Optimizer::PassToken CreateWorkaround1209Pass() {
965   return MakeUnique<Optimizer::PassToken::Impl>(
966       MakeUnique<opt::Workaround1209>());
967 }
968 
CreateIfConversionPass()969 Optimizer::PassToken CreateIfConversionPass() {
970   return MakeUnique<Optimizer::PassToken::Impl>(
971       MakeUnique<opt::IfConversion>());
972 }
973 
CreateReplaceInvalidOpcodePass()974 Optimizer::PassToken CreateReplaceInvalidOpcodePass() {
975   return MakeUnique<Optimizer::PassToken::Impl>(
976       MakeUnique<opt::ReplaceInvalidOpcodePass>());
977 }
978 
CreateSimplificationPass()979 Optimizer::PassToken CreateSimplificationPass() {
980   return MakeUnique<Optimizer::PassToken::Impl>(
981       MakeUnique<opt::SimplificationPass>());
982 }
983 
CreateLoopUnrollPass(bool fully_unroll,int factor)984 Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) {
985   return MakeUnique<Optimizer::PassToken::Impl>(
986       MakeUnique<opt::LoopUnroller>(fully_unroll, factor));
987 }
988 
CreateSSARewritePass()989 Optimizer::PassToken CreateSSARewritePass() {
990   return MakeUnique<Optimizer::PassToken::Impl>(
991       MakeUnique<opt::SSARewritePass>());
992 }
993 
CreateCopyPropagateArraysPass()994 Optimizer::PassToken CreateCopyPropagateArraysPass() {
995   return MakeUnique<Optimizer::PassToken::Impl>(
996       MakeUnique<opt::CopyPropagateArrays>());
997 }
998 
CreateVectorDCEPass()999 Optimizer::PassToken CreateVectorDCEPass() {
1000   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::VectorDCE>());
1001 }
1002 
CreateReduceLoadSizePass(double load_replacement_threshold)1003 Optimizer::PassToken CreateReduceLoadSizePass(
1004     double load_replacement_threshold) {
1005   return MakeUnique<Optimizer::PassToken::Impl>(
1006       MakeUnique<opt::ReduceLoadSize>(load_replacement_threshold));
1007 }
1008 
CreateCombineAccessChainsPass()1009 Optimizer::PassToken CreateCombineAccessChainsPass() {
1010   return MakeUnique<Optimizer::PassToken::Impl>(
1011       MakeUnique<opt::CombineAccessChains>());
1012 }
1013 
CreateUpgradeMemoryModelPass()1014 Optimizer::PassToken CreateUpgradeMemoryModelPass() {
1015   return MakeUnique<Optimizer::PassToken::Impl>(
1016       MakeUnique<opt::UpgradeMemoryModel>());
1017 }
1018 
CreateInstDebugPrintfPass(uint32_t desc_set,uint32_t shader_id)1019 Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set,
1020                                                uint32_t shader_id) {
1021   return MakeUnique<Optimizer::PassToken::Impl>(
1022       MakeUnique<opt::InstDebugPrintfPass>(desc_set, shader_id));
1023 }
1024 
CreateConvertRelaxedToHalfPass()1025 Optimizer::PassToken CreateConvertRelaxedToHalfPass() {
1026   return MakeUnique<Optimizer::PassToken::Impl>(
1027       MakeUnique<opt::ConvertToHalfPass>());
1028 }
1029 
CreateRelaxFloatOpsPass()1030 Optimizer::PassToken CreateRelaxFloatOpsPass() {
1031   return MakeUnique<Optimizer::PassToken::Impl>(
1032       MakeUnique<opt::RelaxFloatOpsPass>());
1033 }
1034 
CreateCodeSinkingPass()1035 Optimizer::PassToken CreateCodeSinkingPass() {
1036   return MakeUnique<Optimizer::PassToken::Impl>(
1037       MakeUnique<opt::CodeSinkingPass>());
1038 }
1039 
CreateFixStorageClassPass()1040 Optimizer::PassToken CreateFixStorageClassPass() {
1041   return MakeUnique<Optimizer::PassToken::Impl>(
1042       MakeUnique<opt::FixStorageClass>());
1043 }
1044 
CreateGraphicsRobustAccessPass()1045 Optimizer::PassToken CreateGraphicsRobustAccessPass() {
1046   return MakeUnique<Optimizer::PassToken::Impl>(
1047       MakeUnique<opt::GraphicsRobustAccessPass>());
1048 }
1049 
CreateReplaceDescArrayAccessUsingVarIndexPass()1050 Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass() {
1051   return MakeUnique<Optimizer::PassToken::Impl>(
1052       MakeUnique<opt::ReplaceDescArrayAccessUsingVarIndex>());
1053 }
1054 
CreateSpreadVolatileSemanticsPass()1055 Optimizer::PassToken CreateSpreadVolatileSemanticsPass() {
1056   return MakeUnique<Optimizer::PassToken::Impl>(
1057       MakeUnique<opt::SpreadVolatileSemantics>());
1058 }
1059 
CreateDescriptorScalarReplacementPass()1060 Optimizer::PassToken CreateDescriptorScalarReplacementPass() {
1061   return MakeUnique<Optimizer::PassToken::Impl>(
1062       MakeUnique<opt::DescriptorScalarReplacement>());
1063 }
1064 
CreateWrapOpKillPass()1065 Optimizer::PassToken CreateWrapOpKillPass() {
1066   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::WrapOpKill>());
1067 }
1068 
CreateAmdExtToKhrPass()1069 Optimizer::PassToken CreateAmdExtToKhrPass() {
1070   return MakeUnique<Optimizer::PassToken::Impl>(
1071       MakeUnique<opt::AmdExtensionToKhrPass>());
1072 }
1073 
CreateInterpolateFixupPass()1074 Optimizer::PassToken CreateInterpolateFixupPass() {
1075   return MakeUnique<Optimizer::PassToken::Impl>(
1076       MakeUnique<opt::InterpFixupPass>());
1077 }
1078 
CreateEliminateDeadInputComponentsPass()1079 Optimizer::PassToken CreateEliminateDeadInputComponentsPass() {
1080   return MakeUnique<Optimizer::PassToken::Impl>(
1081       MakeUnique<opt::EliminateDeadIOComponentsPass>(spv::StorageClass::Input,
1082                                                      /* safe_mode */ false));
1083 }
1084 
CreateEliminateDeadOutputComponentsPass()1085 Optimizer::PassToken CreateEliminateDeadOutputComponentsPass() {
1086   return MakeUnique<Optimizer::PassToken::Impl>(
1087       MakeUnique<opt::EliminateDeadIOComponentsPass>(spv::StorageClass::Output,
1088                                                      /* safe_mode */ false));
1089 }
1090 
CreateEliminateDeadInputComponentsSafePass()1091 Optimizer::PassToken CreateEliminateDeadInputComponentsSafePass() {
1092   return MakeUnique<Optimizer::PassToken::Impl>(
1093       MakeUnique<opt::EliminateDeadIOComponentsPass>(spv::StorageClass::Input,
1094                                                      /* safe_mode */ true));
1095 }
1096 
CreateAnalyzeLiveInputPass(std::unordered_set<uint32_t> * live_locs,std::unordered_set<uint32_t> * live_builtins)1097 Optimizer::PassToken CreateAnalyzeLiveInputPass(
1098     std::unordered_set<uint32_t>* live_locs,
1099     std::unordered_set<uint32_t>* live_builtins) {
1100   return MakeUnique<Optimizer::PassToken::Impl>(
1101       MakeUnique<opt::AnalyzeLiveInputPass>(live_locs, live_builtins));
1102 }
1103 
CreateEliminateDeadOutputStoresPass(std::unordered_set<uint32_t> * live_locs,std::unordered_set<uint32_t> * live_builtins)1104 Optimizer::PassToken CreateEliminateDeadOutputStoresPass(
1105     std::unordered_set<uint32_t>* live_locs,
1106     std::unordered_set<uint32_t>* live_builtins) {
1107   return MakeUnique<Optimizer::PassToken::Impl>(
1108       MakeUnique<opt::EliminateDeadOutputStoresPass>(live_locs, live_builtins));
1109 }
1110 
CreateConvertToSampledImagePass(const std::vector<opt::DescriptorSetAndBinding> & descriptor_set_binding_pairs)1111 Optimizer::PassToken CreateConvertToSampledImagePass(
1112     const std::vector<opt::DescriptorSetAndBinding>&
1113         descriptor_set_binding_pairs) {
1114   return MakeUnique<Optimizer::PassToken::Impl>(
1115       MakeUnique<opt::ConvertToSampledImagePass>(descriptor_set_binding_pairs));
1116 }
1117 
CreateInterfaceVariableScalarReplacementPass()1118 Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass() {
1119   return MakeUnique<Optimizer::PassToken::Impl>(
1120       MakeUnique<opt::InterfaceVariableScalarReplacement>());
1121 }
1122 
CreateRemoveDontInlinePass()1123 Optimizer::PassToken CreateRemoveDontInlinePass() {
1124   return MakeUnique<Optimizer::PassToken::Impl>(
1125       MakeUnique<opt::RemoveDontInline>());
1126 }
1127 
CreateFixFuncCallArgumentsPass()1128 Optimizer::PassToken CreateFixFuncCallArgumentsPass() {
1129   return MakeUnique<Optimizer::PassToken::Impl>(
1130       MakeUnique<opt::FixFuncCallArgumentsPass>());
1131 }
1132 
CreateTrimCapabilitiesPass()1133 Optimizer::PassToken CreateTrimCapabilitiesPass() {
1134   return MakeUnique<Optimizer::PassToken::Impl>(
1135       MakeUnique<opt::TrimCapabilitiesPass>());
1136 }
1137 
CreateSwitchDescriptorSetPass(uint32_t from,uint32_t to)1138 Optimizer::PassToken CreateSwitchDescriptorSetPass(uint32_t from, uint32_t to) {
1139   return MakeUnique<Optimizer::PassToken::Impl>(
1140       MakeUnique<opt::SwitchDescriptorSetPass>(from, to));
1141 }
1142 
CreateInvocationInterlockPlacementPass()1143 Optimizer::PassToken CreateInvocationInterlockPlacementPass() {
1144   return MakeUnique<Optimizer::PassToken::Impl>(
1145       MakeUnique<opt::InvocationInterlockPlacementPass>());
1146 }
1147 
CreateModifyMaximalReconvergencePass(bool add)1148 Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add) {
1149   return MakeUnique<Optimizer::PassToken::Impl>(
1150       MakeUnique<opt::ModifyMaximalReconvergence>(add));
1151 }
1152 
CreateOpExtInstWithForwardReferenceFixupPass()1153 Optimizer::PassToken CreateOpExtInstWithForwardReferenceFixupPass() {
1154   return MakeUnique<Optimizer::PassToken::Impl>(
1155       MakeUnique<opt::OpExtInstWithForwardReferenceFixupPass>());
1156 }
1157 
1158 }  // namespace spvtools
1159 
1160 extern "C" {
1161 
spvOptimizerCreate(spv_target_env env)1162 SPIRV_TOOLS_EXPORT spv_optimizer_t* spvOptimizerCreate(spv_target_env env) {
1163   return reinterpret_cast<spv_optimizer_t*>(new spvtools::Optimizer(env));
1164 }
1165 
spvOptimizerDestroy(spv_optimizer_t * optimizer)1166 SPIRV_TOOLS_EXPORT void spvOptimizerDestroy(spv_optimizer_t* optimizer) {
1167   delete reinterpret_cast<spvtools::Optimizer*>(optimizer);
1168 }
1169 
spvOptimizerSetMessageConsumer(spv_optimizer_t * optimizer,spv_message_consumer consumer)1170 SPIRV_TOOLS_EXPORT void spvOptimizerSetMessageConsumer(
1171     spv_optimizer_t* optimizer, spv_message_consumer consumer) {
1172   reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1173       SetMessageConsumer(
1174           [consumer](spv_message_level_t level, const char* source,
1175                      const spv_position_t& position, const char* message) {
1176             return consumer(level, source, &position, message);
1177           });
1178 }
1179 
spvOptimizerRegisterLegalizationPasses(spv_optimizer_t * optimizer)1180 SPIRV_TOOLS_EXPORT void spvOptimizerRegisterLegalizationPasses(
1181     spv_optimizer_t* optimizer) {
1182   reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1183       RegisterLegalizationPasses();
1184 }
1185 
spvOptimizerRegisterPerformancePasses(spv_optimizer_t * optimizer)1186 SPIRV_TOOLS_EXPORT void spvOptimizerRegisterPerformancePasses(
1187     spv_optimizer_t* optimizer) {
1188   reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1189       RegisterPerformancePasses();
1190 }
1191 
spvOptimizerRegisterSizePasses(spv_optimizer_t * optimizer)1192 SPIRV_TOOLS_EXPORT void spvOptimizerRegisterSizePasses(
1193     spv_optimizer_t* optimizer) {
1194   reinterpret_cast<spvtools::Optimizer*>(optimizer)->RegisterSizePasses();
1195 }
1196 
spvOptimizerRegisterPassFromFlag(spv_optimizer_t * optimizer,const char * flag)1197 SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassFromFlag(
1198     spv_optimizer_t* optimizer, const char* flag)
1199 {
1200   return reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1201       RegisterPassFromFlag(flag);
1202 }
1203 
spvOptimizerRegisterPassesFromFlags(spv_optimizer_t * optimizer,const char ** flags,const size_t flag_count)1204 SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassesFromFlags(
1205     spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) {
1206   std::vector<std::string> opt_flags =
1207       spvtools::GetVectorOfStrings(flags, flag_count);
1208   return reinterpret_cast<spvtools::Optimizer*>(optimizer)
1209       ->RegisterPassesFromFlags(opt_flags, false);
1210 }
1211 
1212 SPIRV_TOOLS_EXPORT bool
spvOptimizerRegisterPassesFromFlagsWhilePreservingTheInterface(spv_optimizer_t * optimizer,const char ** flags,const size_t flag_count)1213 spvOptimizerRegisterPassesFromFlagsWhilePreservingTheInterface(
1214     spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) {
1215   std::vector<std::string> opt_flags =
1216       spvtools::GetVectorOfStrings(flags, flag_count);
1217   return reinterpret_cast<spvtools::Optimizer*>(optimizer)
1218       ->RegisterPassesFromFlags(opt_flags, true);
1219 }
1220 
1221 SPIRV_TOOLS_EXPORT
spvOptimizerRun(spv_optimizer_t * optimizer,const uint32_t * binary,const size_t word_count,spv_binary * optimized_binary,const spv_optimizer_options options)1222 spv_result_t spvOptimizerRun(spv_optimizer_t* optimizer,
1223                              const uint32_t* binary,
1224                              const size_t word_count,
1225                              spv_binary* optimized_binary,
1226                              const spv_optimizer_options options) {
1227   std::vector<uint32_t> optimized;
1228 
1229   if (!reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1230       Run(binary, word_count, &optimized, options)) {
1231     return SPV_ERROR_INTERNAL;
1232   }
1233 
1234   auto result_binary = new spv_binary_t();
1235   if (!result_binary) {
1236       *optimized_binary = nullptr;
1237       return SPV_ERROR_OUT_OF_MEMORY;
1238   }
1239 
1240   result_binary->code = new uint32_t[optimized.size()];
1241   if (!result_binary->code) {
1242       delete result_binary;
1243       *optimized_binary = nullptr;
1244       return SPV_ERROR_OUT_OF_MEMORY;
1245   }
1246   result_binary->wordCount = optimized.size();
1247 
1248   memcpy(result_binary->code, optimized.data(),
1249          optimized.size() * sizeof(uint32_t));
1250 
1251   *optimized_binary = result_binary;
1252 
1253   return SPV_SUCCESS;
1254 }
1255 
1256 }  // extern "C"
1257