• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "test/opt/pass_utils.h"
16 
17 #include <algorithm>
18 #include <sstream>
19 
20 namespace spvtools {
21 namespace opt {
22 namespace {
23 
// Names of the SPIR-V debug-instruction opcodes, used to recognize (and skip)
// debug instructions in instruction lists. Like other hand-written copies of
// grammar knowledge, this list can go stale when SPIR-V adds new debug
// instructions. Generating it automatically would be nicer, but the cost is
// too high for now.
//
// Both the array and the pointed-to strings are const: this is a read-only
// lookup table.
const char* const kDebugOpcodes[] = {
    // clang-format off
    "OpSourceContinued", "OpSource", "OpSourceExtension",
    "OpName", "OpMemberName", "OpString",
    "OpLine", "OpNoLine", "OpModuleProcessed"
    // clang-format on
};
35 
36 }  // anonymous namespace
37 
GetTestMessageConsumer(std::vector<Message> & expected_messages)38 MessageConsumer GetTestMessageConsumer(
39     std::vector<Message>& expected_messages) {
40   return [&expected_messages](spv_message_level_t level, const char* source,
41                               const spv_position_t& position,
42                               const char* message) {
43     EXPECT_TRUE(!expected_messages.empty());
44     if (expected_messages.empty()) {
45       return;
46     }
47 
48     EXPECT_EQ(expected_messages[0].level, level);
49     EXPECT_EQ(expected_messages[0].line_number, position.line);
50     EXPECT_EQ(expected_messages[0].column_number, position.column);
51     EXPECT_STREQ(expected_messages[0].source_file, source);
52     EXPECT_STREQ(expected_messages[0].message, message);
53 
54     expected_messages.erase(expected_messages.begin());
55   };
56 }
57 
// Replaces every non-overlapping occurrence of |find_str| in |*process_str|
// with |replace_str|, scanning left to right. Text produced by a replacement is
// never rescanned, so a |replace_str| containing |find_str| cannot loop.
//
// Returns true if at least one replacement was made. Returns false and leaves
// the string untouched in any of these cases:
//   - |process_str| is null or points to an empty string;
//   - |find_str| is empty;
//   - |find_str| does not occur in |*process_str|.
//
// Runs in time linear in the input and output lengths.
bool FindAndReplace(std::string* process_str, const std::string find_str,
                    const std::string replace_str) {
  if (process_str == nullptr || process_str->empty() || find_str.empty()) {
    return false;
  }
  size_t pos = process_str->find(find_str);
  if (pos == std::string::npos) {
    return false;
  }

  // Build the result in one pass. Replacing in place would shift the tail of
  // the string on every match, which is quadratic overall.
  std::string result;
  result.reserve(process_str->size());
  size_t copied_up_to = 0;
  while (pos != std::string::npos) {
    result.append(*process_str, copied_up_to, pos - copied_up_to);
    result.append(replace_str);
    copied_up_to = pos + find_str.length();
    pos = process_str->find(find_str, copied_up_to);
  }
  result.append(*process_str, copied_up_to, std::string::npos);
  process_str->swap(result);
  return true;
}
74 
ContainsDebugOpcode(const char * inst)75 bool ContainsDebugOpcode(const char* inst) {
76   return std::any_of(std::begin(kDebugOpcodes), std::end(kDebugOpcodes),
77                      [inst](const char* op) {
78                        return std::string(inst).find(op) != std::string::npos;
79                      });
80 }
81 
// Concatenates the entries of |strings| in order, writing |delimiter| after
// each one that is kept. An entry is dropped when |skip_dictator| returns true
// for it. Every kept entry is followed by a delimiter, including the last, so
// the result ends with |delimiter| unless nothing was kept.
std::string SelectiveJoin(const std::vector<const char*>& strings,
                          const std::function<bool(const char*)>& skip_dictator,
                          char delimiter) {
  std::ostringstream joined;
  for (const char* piece : strings) {
    if (skip_dictator(piece)) {
      continue;
    }
    joined << piece;
    joined << delimiter;
  }
  return joined.str();
}
91 
JoinAllInsts(const std::vector<const char * > & insts)92 std::string JoinAllInsts(const std::vector<const char*>& insts) {
93   return SelectiveJoin(insts, [](const char*) { return false; });
94 }
95 
JoinNonDebugInsts(const std::vector<const char * > & insts)96 std::string JoinNonDebugInsts(const std::vector<const char*>& insts) {
97   return SelectiveJoin(
98       insts, [](const char* inst) { return ContainsDebugOpcode(inst); });
99 }
100 
101 }  // namespace opt
102 }  // namespace spvtools
103