• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/vulkan/engine_vulkan.h"
16 
17 #if AMBER_ENABLE_VK_DEBUGGING
18 
#include <stdlib.h>

#include <algorithm>
#include <chrono>              // NOLINT(build/c++11)
#include <cmath>
#include <condition_variable>  // NOLINT(build/c++11)
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>  // NOLINT(build/c++11)
#include <sstream>
#include <string>
#include <thread>  // NOLINT(build/c++11)
#include <unordered_map>
#include <utility>
#include <vector>

#include "dap/network.h"
#include "dap/protocol.h"
#include "dap/session.h"
#include "src/virtual_file_store.h"
33 
// Set to 1 to enable verbose debugger logging
#define ENABLE_DEBUGGER_LOG 0

// DEBUGGER_LOG(fmt, ...) prints a printf-style message followed by a newline
// when ENABLE_DEBUGGER_LOG is non-zero, and expands to nothing otherwise.
// NOTE: when logging is disabled the arguments are discarded by the
// preprocessor (never evaluated or type-checked), so format-string mismatches
// only show up once logging is switched on.
#if ENABLE_DEBUGGER_LOG
#define DEBUGGER_LOG(...) \
  do {                    \
    printf(__VA_ARGS__);  \
    printf("\n");         \
  } while (false)
#else
#define DEBUGGER_LOG(...)
#endif
46 
47 namespace amber {
48 namespace vulkan {
49 
50 namespace {
51 
// Upper bound on how long Thread::Flush() waits for a debugger thread script
// to finish running before reporting a timeout.
constexpr auto kThreadTimeout = std::chrono::minutes(1);

// Upper bound on how long a thread waits for the debugger to report that a
// step request (step-over / step-in / step-out) has completed.
constexpr auto kStepEventTimeout = std::chrono::seconds(1);
58 
// Event is a one-shot flag that threads can block on. Once Signal() has been
// called, every current and future Wait() returns true; nothing ever clears
// the flag again.
class Event {
 public:
  // Wait blocks the caller until Signal() has been called or |duration| has
  // elapsed, whichever comes first. Returns true if the event has been
  // signalled, false if the wait timed out.
  template <typename Rep, typename Period>
  bool Wait(const std::chrono::duration<Rep, Period>& duration) {
    std::unique_lock<std::mutex> guard(mutex_);
    auto is_signalled = [this] { return signalled_; };
    return cv_.wait_for(guard, duration, is_signalled);
  }

  // Signal sets the flag and wakes every thread currently blocked in Wait().
  void Signal() {
    std::lock_guard<std::mutex> guard(mutex_);
    signalled_ = true;
    cv_.notify_all();
  }

 private:
  std::condition_variable cv_;
  std::mutex mutex_;
  bool signalled_ = false;
};
82 
// Split slices |str| into the substrings separated by |sep| and returns them in
// order. Adjacent separators yield empty substrings, and the result always has
// at least one element (an empty |str| yields {""}). An empty |sep| cannot
// delimit anything, so the whole of |str| is returned as a single element.
std::vector<std::string> Split(const std::string& str, const std::string& sep) {
  std::vector<std::string> out;
  if (sep.empty()) {
    // find("") matches at every position; bail out rather than loop.
    out.push_back(str);
    return out;
  }
  std::size_t cur = 0;
  std::size_t prev = 0;
  while ((cur = str.find(sep, prev)) != std::string::npos) {
    out.push_back(str.substr(prev, cur - prev));
    // Skip the entire separator, not just its first character, so that
    // multi-character separators are handled correctly.
    prev = cur + sep.size();
  }
  out.push_back(str.substr(prev));
  return out;
}
96 
// GlobalInvocationId holds a three-element unsigned integer index, used to
// identify a single compute invocation.
// NOTE: this type is a member of a union (InvocationKey::Data), so it must stay
// trivially default-constructible - do not add default member initializers.
struct GlobalInvocationId {
  // hash combines all three components into one value; equal ids always hash
  // equally. Each component is widened to size_t before mixing: shifting the
  // uint32_t fields directly would discard the high bits of x and let large
  // y/z values overlap their neighbouring components.
  size_t hash() const {
    size_t h = static_cast<size_t>(x);
    h = h * 1000003u + static_cast<size_t>(y);
    h = h * 1000003u + static_cast<size_t>(z);
    return h;
  }

  bool operator==(const GlobalInvocationId& other) const {
    return x == other.x && y == other.y && z == other.z;
  }

  uint32_t x;
  uint32_t y;
  uint32_t z;
};
109 
// WindowSpacePosition holds a two-element unsigned integer index, used to
// identify a single fragment invocation.
// NOTE: this type is a member of a union (InvocationKey::Data), so it must stay
// trivially default-constructible - do not add default member initializers.
struct WindowSpacePosition {
  // hash combines both components into one value; equal positions always hash
  // equally. Components are widened to size_t before mixing so that y values
  // of 1024 or more (common for framebuffer heights) no longer overlap x.
  size_t hash() const {
    size_t h = static_cast<size_t>(x);
    h = h * 1000003u + static_cast<size_t>(y);
    return h;
  }

  bool operator==(const WindowSpacePosition& other) const {
    return x == other.x && y == other.y;
  }

  uint32_t x;
  uint32_t y;
};
121 
// Forward declaration: Variables (below) stores Variable elements, while
// Variable in turn holds a Variables list of its children.
struct Variable;

// Variables is a list of Variable values, with helpers for looking up an entry
// by name and for listing all names (used when building error messages).
class Variables : public std::vector<Variable> {
 public:
  // Find returns a pointer to the first element whose name equals |name|, or
  // nullptr if there is no such element.
  inline const Variable* Find(const std::string& name) const;
  // AllNames returns every element's name wrapped in single quotes, joined
  // with ", ".
  inline std::string AllNames() const;
};
131 
132 // Variable holds a debugger returned named value (local, global, etc).
133 // Variables can hold child variables (for structs, arrays, etc).
134 struct Variable {
135   std::string name;
136   std::string value;
137   Variables children;
138 
139   // Get parses the Variable value for the requested type, assigning the result
140   // to |out|. Returns true on success, otherwise false.
Getamber::vulkan::__anonf24a5b550111::Variable141   bool Get(uint32_t* out) const {
142     *out = static_cast<uint32_t>(std::atoi(value.c_str()));
143     return true;  // TODO(bclayton): Verify the value parsed correctly.
144   }
145 
Getamber::vulkan::__anonf24a5b550111::Variable146   bool Get(int64_t* out) const {
147     *out = static_cast<int64_t>(std::atoi(value.c_str()));
148     return true;  // TODO(bclayton): Verify the value parsed correctly.
149   }
150 
Getamber::vulkan::__anonf24a5b550111::Variable151   bool Get(double* out) const {
152     *out = std::atof(value.c_str());
153     return true;  // TODO(bclayton): Verify the value parsed correctly.
154   }
155 
Getamber::vulkan::__anonf24a5b550111::Variable156   bool Get(std::string* out) const {
157     *out = value;
158     return true;
159   }
160 
Getamber::vulkan::__anonf24a5b550111::Variable161   bool Get(GlobalInvocationId* out) const {
162     auto x = children.Find("x");
163     auto y = children.Find("y");
164     auto z = children.Find("z");
165     return (x != nullptr && y != nullptr && z != nullptr && x->Get(&out->x) &&
166             y->Get(&out->y) && z->Get(&out->z));
167   }
168 
Getamber::vulkan::__anonf24a5b550111::Variable169   bool Get(WindowSpacePosition* out) const {
170     auto x = children.Find("x");
171     auto y = children.Find("y");
172     return (x != nullptr && y != nullptr && x->Get(&out->x) && y->Get(&out->y));
173   }
174 };
175 
Find(const std::string & name) const176 const Variable* Variables::Find(const std::string& name) const {
177   for (auto& child : *this) {
178     if (child.name == name) {
179       return &child;
180     }
181   }
182   return nullptr;
183 }
184 
AllNames() const185 std::string Variables::AllNames() const {
186   std::string out;
187   for (auto& var : *this) {
188     if (out.size() > 0) {
189       out += ", ";
190     }
191     out += "'" + var.name + "'";
192   }
193   return out;
194 }
195 
196 // Client wraps a dap::Session and a error handler, and provides a more
197 // convenient interface for talking to the debugger. Client also provides basic
198 // immutable data caching to help performance.
199 class Client {
200   static constexpr const char* kLocals = "locals";
201   static constexpr const char* kLane = "Lane";
202 
203  public:
204   using ErrorHandler = std::function<void(const std::string&)>;
205   using SourceLines = std::vector<std::string>;
206 
Client(const std::shared_ptr<dap::Session> & session,const ErrorHandler & onerror)207   Client(const std::shared_ptr<dap::Session>& session,
208          const ErrorHandler& onerror)
209       : session_(session), onerror_(onerror) {}
210 
211   // TopStackFrame retrieves the frame at the top of the thread's call stack.
212   // Returns true on success, false on error.
TopStackFrame(dap::integer thread_id,dap::StackFrame * frame)213   bool TopStackFrame(dap::integer thread_id, dap::StackFrame* frame) {
214     std::vector<dap::StackFrame> stack;
215     if (!Callstack(thread_id, &stack)) {
216       return false;
217     }
218     *frame = stack.front();
219     return true;
220   }
221 
222   // Callstack retrieves the thread's full call stack.
223   // Returns true on success, false on error.
Callstack(dap::integer thread_id,std::vector<dap::StackFrame> * stack)224   bool Callstack(dap::integer thread_id, std::vector<dap::StackFrame>* stack) {
225     dap::StackTraceRequest request;
226     request.threadId = thread_id;
227     auto response = session_->send(request).get();
228     if (response.error) {
229       onerror_(response.error.message);
230       return false;
231     }
232     if (response.response.stackFrames.size() == 0) {
233       onerror_("Stack frame is empty");
234       return false;
235     }
236     *stack = response.response.stackFrames;
237     return true;
238   }
239 
240   // FrameLocation retrieves the current frame source location, and optional
241   // source line text.
242   // Returns true on success, false on error.
FrameLocation(VirtualFileStore * virtual_files,const dap::StackFrame & frame,debug::Location * location,std::string * line=nullptr)243   bool FrameLocation(VirtualFileStore* virtual_files,
244                      const dap::StackFrame& frame,
245                      debug::Location* location,
246                      std::string* line = nullptr) {
247     location->line = static_cast<uint32_t>(frame.line);
248 
249     if (!frame.source.has_value()) {
250       onerror_("Stack frame with name '" + frame.name + "' has no source");
251       return false;
252     } else if (frame.source->path.has_value()) {
253       location->file = frame.source.value().path.value();
254     } else if (frame.source->name.has_value()) {
255       location->file = frame.source.value().name.value();
256     } else {
257       onerror_("Frame source had no path or name");
258       return false;
259     }
260 
261     if (location->line < 1) {
262       onerror_("Line location is " + std::to_string(location->line));
263       return false;
264     }
265 
266     if (line != nullptr) {
267       SourceLines lines;
268       if (!SourceContent(virtual_files, frame.source.value(), &lines)) {
269         return false;
270       }
271       if (location->line > lines.size()) {
272         onerror_("Line " + std::to_string(location->line) +
273                  " is greater than the number of lines in the source file (" +
274                  std::to_string(lines.size()) + ")");
275         return false;
276       }
277       if (location->line < 1) {
278         onerror_("Line number reported was " + std::to_string(location->line));
279         return false;
280       }
281       *line = lines[location->line - 1];
282     }
283 
284     return true;
285   }
286 
287   // SourceContext retrieves the the SourceLines for the given source.
288   // Returns true on success, false on error.
SourceContent(VirtualFileStore * virtual_files,const dap::Source & source,SourceLines * out)289   bool SourceContent(VirtualFileStore* virtual_files,
290                      const dap::Source& source,
291                      SourceLines* out) {
292     auto path = source.path.value("");
293 
294     if (source.sourceReference.has_value()) {
295       auto ref = source.sourceReference.value();
296       auto it = sourceCache_.by_ref.find(ref);
297       if (it != sourceCache_.by_ref.end()) {
298         *out = it->second;
299         return true;
300       }
301 
302       dap::SourceRequest request;
303       dap::SourceResponse response;
304       request.sourceReference = ref;
305       if (!Send(request, &response)) {
306         return false;
307       }
308       auto lines = Split(response.content, "\n");
309       sourceCache_.by_ref.emplace(ref, lines);
310       *out = lines;
311       return true;
312     }
313 
314     if (path != "") {
315       auto it = sourceCache_.by_path.find(path);
316       if (it != sourceCache_.by_path.end()) {
317         *out = it->second;
318         return true;
319       }
320 
321       // First search the virtual file store.
322       std::string content;
323       if (virtual_files->Get(path, &content).IsSuccess()) {
324         auto lines = Split(content, "\n");
325         sourceCache_.by_path.emplace(path, lines);
326         *out = lines;
327         return true;
328       }
329 
330       // TODO(bclayton) - We shouldn't be doing direct file IO here. We should
331       // bubble the IO request to the amber 'embedder'.
332       // See: https://github.com/google/amber/issues/777
333       if (auto file = std::ifstream(path)) {
334         SourceLines lines;
335         std::string line;
336         while (std::getline(file, line)) {
337           lines.emplace_back(line);
338         }
339 
340         sourceCache_.by_path.emplace(path, lines);
341         *out = lines;
342         return true;
343       }
344     }
345 
346     onerror_("Could not get source content");
347     return false;
348   }
349 
350   // Send sends the request to the debugger, waits for the request to complete,
351   // and then assigns the response to |res|.
352   // Returns true on success, false on error.
353   template <typename REQUEST, typename RESPONSE>
Send(const REQUEST & request,RESPONSE * res)354   bool Send(const REQUEST& request, RESPONSE* res) {
355     auto r = session_->send(request).get();
356     if (r.error) {
357       onerror_(r.error.message);
358       return false;
359     }
360     *res = r.response;
361     return true;
362   }
363 
364   // Send sends the request to the debugger, and waits for the request to
365   // complete.
366   // Returns true on success, false on error.
367   template <typename REQUEST>
Send(const REQUEST & request)368   bool Send(const REQUEST& request) {
369     using RESPONSE = typename REQUEST::Response;
370     RESPONSE response;
371     return Send(request, &response);
372   }
373 
374   // GetVariables fetches the fully traversed set of Variables from the debugger
375   // for the given reference identifier.
376   // Returns true on success, false on error.
GetVariables(dap::integer variablesRef,Variables * out)377   bool GetVariables(dap::integer variablesRef, Variables* out) {
378     dap::VariablesRequest request;
379     dap::VariablesResponse response;
380     request.variablesReference = variablesRef;
381     if (!Send(request, &response)) {
382       return false;
383     }
384     for (auto var : response.variables) {
385       Variable v;
386       v.name = var.name;
387       v.value = var.value;
388       if (var.variablesReference > 0) {
389         if (!GetVariables(var.variablesReference, &v.children)) {
390           return false;
391         }
392       }
393       out->emplace_back(v);
394     }
395     return true;
396   }
397 
398   // GetLocals fetches the fully traversed set of local Variables from the
399   // debugger for the given stack frame.
400   // Returns true on success, false on error.
GetLocals(const dap::StackFrame & frame,Variables * out)401   bool GetLocals(const dap::StackFrame& frame, Variables* out) {
402     dap::ScopesRequest scopeReq;
403     dap::ScopesResponse scopeRes;
404     scopeReq.frameId = frame.id;
405     if (!Send(scopeReq, &scopeRes)) {
406       return false;
407     }
408 
409     for (auto scope : scopeRes.scopes) {
410       if (scope.presentationHint.value("") == kLocals) {
411         return GetVariables(scope.variablesReference, out);
412       }
413     }
414 
415     onerror_("Locals scope not found");
416     return false;
417   }
418 
419   // GetLane returns a pointer to the Variables representing the thread's SIMD
420   // lane with the given index, or nullptr if the lane was not found.
GetLane(const Variables & lanes,int lane)421   const Variables* GetLane(const Variables& lanes, int lane) {
422     auto out = lanes.Find(std::string(kLane) + " " + std::to_string(lane));
423     if (out == nullptr) {
424       return nullptr;
425     }
426     return &out->children;
427   }
428 
429  private:
430   struct SourceCache {
431     std::unordered_map<int64_t, SourceLines> by_ref;
432     std::unordered_map<std::string, SourceLines> by_path;
433   };
434 
435   std::shared_ptr<dap::Session> session_;
436   ErrorHandler onerror_;
437   SourceCache sourceCache_;
438 };
439 
// InvocationKey is a tagged-union structure that identifies a single shader
// invocation. |type| selects which member of |data| is active.
struct InvocationKey {
  // Hash is a custom hasher that can enable InvocationKeys to be used as keys
  // in std containers.
  struct Hash {
    size_t operator()(const InvocationKey& key) const;
  };

  // Discriminator for |data|:
  //   kGlobalInvocationId  -> data.global_invocation_id
  //   kVertexIndex         -> data.vertex_id
  //   kWindowSpacePosition -> data.window_space_position
  enum class Type { kGlobalInvocationId, kVertexIndex, kWindowSpacePosition };
  // NOTE: the single-argument constructors default-initialize |data| and then
  // assign one member, so every member type must remain trivially
  // default-constructible (no default member initializers).
  union Data {
    GlobalInvocationId global_invocation_id;
    uint32_t vertex_id;
    WindowSpacePosition window_space_position;
  };

  explicit InvocationKey(const GlobalInvocationId&);
  explicit InvocationKey(const WindowSpacePosition&);
  // General form; the only way to construct a kVertexIndex key.
  InvocationKey(Type, const Data&);

  // Keys are equal when their types match and their active members compare
  // equal.
  bool operator==(const InvocationKey& other) const;

  // String returns a human-readable description of the key.
  std::string String() const;

  Type type;
  Data data;
};
468 
operator ()(const InvocationKey & key) const469 size_t InvocationKey::Hash::operator()(const InvocationKey& key) const {
470   size_t hash = 31 * static_cast<size_t>(key.type);
471   switch (key.type) {
472     case Type::kGlobalInvocationId:
473       hash += key.data.global_invocation_id.hash();
474       break;
475     case Type::kVertexIndex:
476       hash += key.data.vertex_id;
477       break;
478     case Type::kWindowSpacePosition:
479       hash += key.data.window_space_position.hash();
480       break;
481   }
482   return hash;
483 }
484 
InvocationKey(const GlobalInvocationId & id)485 InvocationKey::InvocationKey(const GlobalInvocationId& id)
486     : type(Type::kGlobalInvocationId) {
487   data.global_invocation_id = id;
488 }
489 
InvocationKey(const WindowSpacePosition & pos)490 InvocationKey::InvocationKey(const WindowSpacePosition& pos)
491     : type(Type::kWindowSpacePosition) {
492   data.window_space_position = pos;
493 }
494 
InvocationKey(Type type_,const Data & data_)495 InvocationKey::InvocationKey(Type type_, const Data& data_)
496     : type(type_), data(data_) {}
497 
String() const498 std::string InvocationKey::String() const {
499   std::stringstream ss;
500   switch (type) {
501     case Type::kGlobalInvocationId:
502       ss << "GlobalInvocation(" << data.global_invocation_id.x << ", "
503          << data.global_invocation_id.y << ", " << data.global_invocation_id.z
504          << ")";
505       break;
506     case Type::kVertexIndex:
507       ss << "VertexIndex(" << data.vertex_id << ")";
508       break;
509     case Type::kWindowSpacePosition:
510       ss << "WindowSpacePosition(" << data.window_space_position.x << ", "
511          << data.window_space_position.y << ")";
512       break;
513   }
514   return ss.str();
515 }
516 
operator ==(const InvocationKey & other) const517 bool InvocationKey::operator==(const InvocationKey& other) const {
518   if (type != other.type) {
519     return false;
520   }
521   switch (type) {
522     case Type::kGlobalInvocationId:
523       return data.global_invocation_id == other.data.global_invocation_id;
524     case Type::kVertexIndex:
525       return data.vertex_id == other.data.vertex_id;
526     case Type::kWindowSpacePosition:
527       return data.window_space_position == other.data.window_space_position;
528   }
529   return false;
530 }
531 
532 // Thread controls and verifies a single debugger thread of execution.
533 class Thread : public debug::Thread {
534  public:
Thread(VirtualFileStore * virtual_files,std::shared_ptr<dap::Session> session,dap::integer threadId,int lane,std::shared_ptr<const debug::ThreadScript> script)535   Thread(VirtualFileStore* virtual_files,
536          std::shared_ptr<dap::Session> session,
537          dap::integer threadId,
538          int lane,
539          std::shared_ptr<const debug::ThreadScript> script)
540       : virtual_files_(virtual_files),
541         thread_id_(threadId),
542         lane_(lane),
543         client_(session, [this](const std::string& err) { OnError(err); }) {
544     // The thread script runs concurrently with other debugger thread scripts.
545     // Run on a separate amber thread.
__anonf24a5b550402null546     thread_ = std::thread([this, script] {
547       script->Run(this);  // Begin running the thread script.
548       done_.Signal();     // Signal when done.
549     });
550   }
551 
~Thread()552   ~Thread() override { Flush(); }
553 
554   // Flush waits for the debugger thread script to complete, and returns any
555   // errors encountered.
Flush()556   Result Flush() {
557     if (done_.Wait(kThreadTimeout)) {
558       if (thread_.joinable()) {
559         thread_.join();
560       }
561     } else {
562       error_ += "Timed out performing actions";
563     }
564     return error_;
565   }
566 
567   // OnStep is called when the debugger sends a 'stopped' event with reason
568   // 'step'.
OnStep()569   void OnStep() { on_stepped_.Signal(); }
570 
571   // WaitForSteppedEvent halts execution of the thread until OnStep() is
572   // called, or kStepEventTimeout is reached.
WaitForStepEvent(const std::string & request)573   void WaitForStepEvent(const std::string& request) {
574     if (!on_stepped_.Wait(kStepEventTimeout)) {
575       OnError("Timeout waiting for " + request + " request to complete");
576     }
577   }
578 
579   // debug::Thread compliance
StepOver()580   void StepOver() override {
581     DEBUGGER_LOG("StepOver()");
582     dap::NextRequest request;
583     request.threadId = thread_id_;
584     client_.Send(request);
585     WaitForStepEvent("step-over");
586   }
587 
StepIn()588   void StepIn() override {
589     DEBUGGER_LOG("StepIn()");
590     dap::StepInRequest request;
591     request.threadId = thread_id_;
592     client_.Send(request);
593     WaitForStepEvent("step-in");
594   }
595 
StepOut()596   void StepOut() override {
597     DEBUGGER_LOG("StepOut()");
598     dap::StepOutRequest request;
599     request.threadId = thread_id_;
600     client_.Send(request);
601     WaitForStepEvent("step-out");
602   }
603 
Continue()604   void Continue() override {
605     DEBUGGER_LOG("Continue()");
606     dap::ContinueRequest request;
607     request.threadId = thread_id_;
608     client_.Send(request);
609   }
610 
ExpectLocation(const debug::Location & location,const std::string & line)611   void ExpectLocation(const debug::Location& location,
612                       const std::string& line) override {
613     DEBUGGER_LOG("ExpectLocation('%s', %d)", location.file.c_str(),
614                  location.line);
615 
616     dap::StackFrame frame;
617     if (!client_.TopStackFrame(thread_id_, &frame)) {
618       return;
619     }
620 
621     debug::Location got_location;
622     std::string got_source_line;
623     if (!client_.FrameLocation(virtual_files_, frame, &got_location,
624                                &got_source_line)) {
625       return;
626     }
627 
628     if (got_location.file != location.file) {
629       OnError("Expected file to be '" + location.file + "' but file was '" +
630               got_location.file + "'");
631     } else if (got_location.line != location.line) {
632       std::stringstream ss;
633       ss << "Expected line " << std::to_string(location.line);
634       if (line != "") {
635         ss << " `" << line << "`";
636       }
637       ss << " but line was " << std::to_string(got_location.line) << " `"
638          << got_source_line << "`";
639       OnError(ss.str());
640     } else if (line != "" && got_source_line != line) {
641       OnError("Expected source for line " + std::to_string(location.line) +
642               " to be: `" + line + "` but line was: `" + got_source_line + "`");
643     }
644   }
645 
ExpectCallstack(const std::vector<debug::StackFrame> & callstack)646   void ExpectCallstack(
647       const std::vector<debug::StackFrame>& callstack) override {
648     DEBUGGER_LOG("ExpectCallstack()");
649 
650     std::vector<dap::StackFrame> got_stack;
651     if (!client_.Callstack(thread_id_, &got_stack)) {
652       return;
653     }
654 
655     std::stringstream ss;
656 
657     size_t count = std::min(callstack.size(), got_stack.size());
658     for (size_t i = 0; i < count; i++) {
659       auto const& got_frame = got_stack[i];
660       auto const& want_frame = callstack[i];
661       bool ok = got_frame.name == want_frame.name;
662       if (ok && want_frame.location.file != "") {
663         ok = got_frame.source.has_value() &&
664              got_frame.source->name.value("") == want_frame.location.file;
665       }
666       if (ok && want_frame.location.line != 0) {
667         ok = got_frame.line == static_cast<int>(want_frame.location.line);
668       }
669       if (!ok) {
670         ss << "Unexpected stackframe at frame " << i
671            << "\nGot:      " << FrameString(got_frame)
672            << "\nExpected: " << FrameString(want_frame) << "\n";
673       }
674     }
675 
676     if (got_stack.size() > callstack.size()) {
677       ss << "Callstack has an additional "
678          << (got_stack.size() - callstack.size()) << " unexpected frames\n";
679     } else if (callstack.size() > got_stack.size()) {
680       ss << "Callstack is missing " << (callstack.size() - got_stack.size())
681          << " frames\n";
682     }
683 
684     if (ss.str().size() > 0) {
685       ss << "Full callstack:\n";
686       for (auto& frame : got_stack) {
687         ss << "  " << FrameString(frame) << "\n";
688       }
689       OnError(ss.str());
690     }
691   }
692 
ExpectLocal(const std::string & name,int64_t value)693   void ExpectLocal(const std::string& name, int64_t value) override {
694     DEBUGGER_LOG("ExpectLocal('%s', %d)", name.c_str(), (int)value);
695     ExpectLocalT(name, value, [](int64_t a, int64_t b) { return a == b; });
696   }
697 
ExpectLocal(const std::string & name,double value)698   void ExpectLocal(const std::string& name, double value) override {
699     DEBUGGER_LOG("ExpectLocal('%s', %f)", name.c_str(), value);
700     ExpectLocalT(name, value,
701                  [](double a, double b) { return std::abs(a - b) < 0.00001; });
702   }
703 
ExpectLocal(const std::string & name,const std::string & value)704   void ExpectLocal(const std::string& name, const std::string& value) override {
705     DEBUGGER_LOG("ExpectLocal('%s', '%s')", name.c_str(), value.c_str());
706     ExpectLocalT(name, value, [](const std::string& a, const std::string& b) {
707       return a == b;
708     });
709   }
710 
711   // ExpectLocalT verifies that the local variable with the given name has the
712   // expected value. |name| may contain `.` delimiters to index structure or
713   // array types. ExpectLocalT uses the function equal to compare the expected
714   // and actual values, returning true if considered equal, otherwise false.
715   // EQUAL_FUNC is a function-like type with the signature:
716   //   bool(T a, T b)
717   template <typename T, typename EQUAL_FUNC>
ExpectLocalT(const std::string & name,const T & expect,EQUAL_FUNC && equal)718   void ExpectLocalT(const std::string& name,
719                     const T& expect,
720                     EQUAL_FUNC&& equal) {
721     dap::StackFrame frame;
722     if (!client_.TopStackFrame(thread_id_, &frame)) {
723       return;
724     }
725 
726     Variables locals;
727     if (!client_.GetLocals(frame, &locals)) {
728       return;
729     }
730 
731     if (auto lane = client_.GetLane(locals, lane_)) {
732       auto owner = lane;
733       const Variable* var = nullptr;
734       std::string path;
735       for (auto part : Split(name, ".")) {
736         var = owner->Find(part);
737         if (!var) {
738           if (owner == lane) {
739             OnError("Local '" + name + "' not found.\nLocals: " +
740                     lane->AllNames() + ".\nLanes: " + locals.AllNames() + ".");
741           } else {
742             OnError("Local '" + path + "' does not contain '" + part +
743                     "'\nChildren: " + owner->AllNames());
744           }
745           return;
746         }
747         owner = &var->children;
748         path += (path.size() > 0) ? "." + part : part;
749       }
750 
751       T got = {};
752       if (!var->Get(&got)) {
753         OnError("Local '" + name + "' was not of expected type");
754         return;
755       }
756 
757       if (!equal(got, expect)) {
758         std::stringstream ss;
759         ss << "Local '" << name << "' did not have expected value. Value is '"
760            << got << "', expected '" << expect << "'";
761         OnError(ss.str());
762         return;
763       }
764     }
765   }
766 
767  private:
OnError(const std::string & err)768   void OnError(const std::string& err) {
769     DEBUGGER_LOG("ERROR: %s", err.c_str());
770     error_ += err;
771   }
772 
FrameString(const dap::StackFrame & frame)773   std::string FrameString(const dap::StackFrame& frame) {
774     std::stringstream ss;
775     ss << frame.name;
776     if (frame.source.has_value() && frame.source->name.has_value()) {
777       ss << " " << frame.source->name.value() << ":" << frame.line;
778     }
779     return ss.str();
780   }
781 
FrameString(const debug::StackFrame & frame)782   std::string FrameString(const debug::StackFrame& frame) {
783     std::stringstream ss;
784     ss << frame.name;
785     if (frame.location.file != "") {
786       ss << " " << frame.location.file;
787       if (frame.location.line != 0) {
788         ss << ":" << frame.location.line;
789       }
790     }
791     return ss.str();
792   }
793 
794   VirtualFileStore* const virtual_files_;
795   const dap::integer thread_id_;
796   const int lane_;
797   Client client_;
798   std::thread thread_;
799   Event done_;
800   Event on_stepped_;
801   Result error_;
802 };
803 
804 // VkDebugger is a private implementation of the Engine::Debugger interface.
805 class VkDebugger : public Engine::Debugger {
806   static constexpr const char* kComputeShaderFunctionName = "ComputeShader";
807   static constexpr const char* kVertexShaderFunctionName = "VertexShader";
808   static constexpr const char* kFragmentShaderFunctionName = "FragmentShader";
809   static constexpr const char* kGlobalInvocationId = "globalInvocationId";
810   static constexpr const char* kWindowSpacePosition = "windowSpacePosition";
811   static constexpr const char* kVertexIndex = "vertexIndex";
812 
813  public:
VkDebugger(VirtualFileStore * virtual_files)814   explicit VkDebugger(VirtualFileStore* virtual_files)
815       : virtual_files_(virtual_files) {}
816 
817   /// Connect establishes the connection to the shader debugger. Must be
818   /// called before any of the |debug::Events| methods.
Connect()819   Result Connect() {
820     auto port_str = getenv("VK_DEBUGGER_PORT");
821     if (!port_str) {
822       return Result("The environment variable VK_DEBUGGER_PORT was not set\n");
823     }
824     int port = atoi(port_str);
825     if (port == 0) {
826       return Result("Could not parse port number from VK_DEBUGGER_PORT ('" +
827                     std::string(port_str) + "')");
828     }
829 
830     constexpr int kMaxAttempts = 10;
831     // The socket might take a while to open - retry connecting.
832     for (int attempt = 0; attempt < kMaxAttempts; attempt++) {
833       auto connection = dap::net::connect("localhost", port);
834       if (!connection) {
835         std::this_thread::sleep_for(std::chrono::seconds(1));
836         continue;
837       }
838 
839       // Socket opened. Create the debugger session and bind.
840       session_ = dap::Session::create();
841       session_->bind(connection);
842 
843       // Register the thread stopped event.
844       // This is fired when breakpoints are hit (amongst other reasons).
845       // See:
846       // https://microsoft.github.io/debug-adapter-protocol/specification#Events_Stopped
847       session_->registerHandler([&](const dap::StoppedEvent& event) {
848         DEBUGGER_LOG("THREAD STOPPED. Reason: %s", event.reason.c_str());
849         if (event.reason == "function breakpoint" ||
850             event.reason == "breakpoint") {
851           OnBreakpointHit(event.threadId.value(0));
852         } else if (event.reason == "step") {
853           OnStep(event.threadId.value(0));
854         }
855       });
856 
857       // Start the debugger initialization sequence.
858       // See: https://microsoft.github.io/debug-adapter-protocol/overview for
859       // details.
860 
861       dap::InitializeRequest init_req = {};
862       auto init_res = session_->send(init_req).get();
863       if (init_res.error) {
864         DEBUGGER_LOG("InitializeRequest failed: %s",
865                      init_res.error.message.c_str());
866         return Result(init_res.error.message);
867       }
868 
869       // Set breakpoints on the various shader types, we do this even if we
870       // don't actually care about these threads. Once the breakpoint is hit,
871       // the pendingThreads_ map is probed, if nothing matches the thread is
872       // resumed.
873       // TODO(bclayton): Once we have conditional breakpoint support, we can
874       // reduce the number of breakpoints / scope of breakpoints.
875       dap::SetFunctionBreakpointsRequest fbp_req = {};
876       dap::FunctionBreakpoint fbp = {};
877       fbp.name = kComputeShaderFunctionName;
878       fbp_req.breakpoints.emplace_back(fbp);
879       fbp.name = kVertexShaderFunctionName;
880       fbp_req.breakpoints.emplace_back(fbp);
881       fbp.name = kFragmentShaderFunctionName;
882       fbp_req.breakpoints.emplace_back(fbp);
883       auto fbp_res = session_->send(fbp_req).get();
884       if (fbp_res.error) {
885         DEBUGGER_LOG("SetFunctionBreakpointsRequest failed: %s",
886                      fbp_res.error.message.c_str());
887         return Result(fbp_res.error.message);
888       }
889 
890       // ConfigurationDone signals the initialization has completed.
891       dap::ConfigurationDoneRequest cfg_req = {};
892       auto cfg_res = session_->send(cfg_req).get();
893       if (cfg_res.error) {
894         DEBUGGER_LOG("ConfigurationDoneRequest failed: %s",
895                      cfg_res.error.message.c_str());
896         return Result(cfg_res.error.message);
897       }
898 
899       return Result();
900     }
901     return Result("Unable to connect to debugger on port " +
902                   std::to_string(port));
903   }
904 
905   // Flush checks that all breakpoints were hit, waits for all threads to
906   // complete, and returns the globbed together results for all threads.
Flush()907   Result Flush() override {
908     Result result;
909     {
910       std::unique_lock<std::mutex> lock(error_mutex_);
911       result += error_;
912     }
913     {
914       std::unique_lock<std::mutex> lock(threads_mutex_);
915       for (auto& pending : pendingThreads_) {
916         result += "Thread did not run: " + pending.first.String();
917       }
918       for (auto& it : runningThreads_) {
919         result += it.second->Flush();
920       }
921       runningThreads_.clear();
922       completedThreads_.clear();
923     }
924     return result;
925   }
926 
927   // debug::Events compliance
BreakOnComputeGlobalInvocation(uint32_t x,uint32_t y,uint32_t z,const std::shared_ptr<const debug::ThreadScript> & script)928   void BreakOnComputeGlobalInvocation(
929       uint32_t x,
930       uint32_t y,
931       uint32_t z,
932       const std::shared_ptr<const debug::ThreadScript>& script) override {
933     std::unique_lock<std::mutex> lock(threads_mutex_);
934     pendingThreads_.emplace(GlobalInvocationId{x, y, z}, script);
935   }
936 
BreakOnVertexIndex(uint32_t index,const std::shared_ptr<const debug::ThreadScript> & script)937   void BreakOnVertexIndex(
938       uint32_t index,
939       const std::shared_ptr<const debug::ThreadScript>& script) override {
940     InvocationKey::Data data;
941     data.vertex_id = index;
942     auto key = InvocationKey{InvocationKey::Type::kVertexIndex, data};
943     std::unique_lock<std::mutex> lock(threads_mutex_);
944     pendingThreads_.emplace(key, script);
945   }
946 
BreakOnFragmentWindowSpacePosition(uint32_t x,uint32_t y,const std::shared_ptr<const debug::ThreadScript> & script)947   void BreakOnFragmentWindowSpacePosition(
948       uint32_t x,
949       uint32_t y,
950       const std::shared_ptr<const debug::ThreadScript>& script) override {
951     std::unique_lock<std::mutex> lock(threads_mutex_);
952     pendingThreads_.emplace(WindowSpacePosition{x, y}, script);
953   }
954 
955  private:
StartThread(dap::integer thread_id,int lane,const std::shared_ptr<const amber::debug::ThreadScript> & script)956   void StartThread(
957       dap::integer thread_id,
958       int lane,
959       const std::shared_ptr<const amber::debug::ThreadScript>& script) {
960     auto thread =
961         MakeUnique<Thread>(virtual_files_, session_, thread_id, lane, script);
962     auto it = runningThreads_.find(thread_id);
963     if (it != runningThreads_.end()) {
964       // The debugger has now started working on something else for the given
965       // thread id.
966       // Flush the Thread to ensure that all tasks have actually finished, and
967       // move the Thread to the completedThreads_ list.
968       auto completed_thread = std::move(it->second);
969       error_ += completed_thread->Flush();
970       completedThreads_.emplace_back(std::move(completed_thread));
971     }
972     runningThreads_[thread_id] = std::move(thread);
973   }
974 
975   // OnBreakpointHit is called when a debugger breakpoint is hit (breakpoints
976   // are set at shader entry points). pendingThreads_ is checked to see if this
977   // thread needs testing, and if so, creates a new ::Thread.
978   // If there's no pendingThread_ entry for the given thread, it is resumed to
979   // allow the shader to continue executing.
OnBreakpointHit(dap::integer thread_id)980   void OnBreakpointHit(dap::integer thread_id) {
981     DEBUGGER_LOG("Breakpoint hit: thread %d", (int)thread_id);
982     Client client(session_, [this](const std::string& err) { OnError(err); });
983 
984     std::unique_lock<std::mutex> lock(threads_mutex_);
985     for (auto it = pendingThreads_.begin(); it != pendingThreads_.end(); it++) {
986       auto& key = it->first;
987       auto& script = it->second;
988       switch (key.type) {
989         case InvocationKey::Type::kGlobalInvocationId: {
990           auto invocation_id = key.data.global_invocation_id;
991           int lane;
992           if (FindGlobalInvocationId(thread_id, invocation_id, &lane)) {
993             DEBUGGER_LOG("Breakpoint hit: GetGlobalInvocationId: [%d, %d, %d]",
994                          invocation_id.x, invocation_id.y, invocation_id.z);
995             StartThread(thread_id, lane, script);
996             pendingThreads_.erase(it);
997             return;
998           }
999           break;
1000         }
1001         case InvocationKey::Type::kVertexIndex: {
1002           auto vertex_id = key.data.vertex_id;
1003           int lane;
1004           if (FindVertexIndex(thread_id, vertex_id, &lane)) {
1005             DEBUGGER_LOG("Breakpoint hit: VertexId: %d", vertex_id);
1006             StartThread(thread_id, lane, script);
1007             pendingThreads_.erase(it);
1008             return;
1009           }
1010           break;
1011         }
1012         case InvocationKey::Type::kWindowSpacePosition: {
1013           auto position = key.data.window_space_position;
1014           int lane;
1015           if (FindWindowSpacePosition(thread_id, position, &lane)) {
1016             DEBUGGER_LOG("Breakpoint hit: VertexId: [%d, %d]", position.x,
1017                          position.y);
1018             StartThread(thread_id, lane, script);
1019             pendingThreads_.erase(it);
1020             return;
1021           }
1022           break;
1023         }
1024       }
1025     }
1026 
1027     // No pending tests for this thread. Let it carry on...
1028     dap::ContinueRequest request;
1029     request.threadId = thread_id;
1030     client.Send(request);
1031   }
1032 
1033   // OnStep is called when a debugger finishes a step request.
OnStep(dap::integer thread_id)1034   void OnStep(dap::integer thread_id) {
1035     DEBUGGER_LOG("Step event: thread %d", (int)thread_id);
1036     Client client(session_, [this](const std::string& err) { OnError(err); });
1037 
1038     std::unique_lock<std::mutex> lock(threads_mutex_);
1039     auto it = runningThreads_.find(thread_id);
1040     if (it == runningThreads_.end()) {
1041       OnError("Step event reported for non-running thread " +
1042               std::to_string(thread_id.operator int64_t()));
1043     }
1044 
1045     it->second->OnStep();
1046   }
1047 
1048   // FindLocal looks for the shader's local variable with the given name and
1049   // value in the stack frames' locals, returning true if found, and assigns the
1050   // index of the SIMD lane it was found in to |lane|.
1051   template <typename T>
FindLocal(dap::integer thread_id,const char * name,const T & value,int * lane)1052   bool FindLocal(dap::integer thread_id,
1053                  const char* name,
1054                  const T& value,
1055                  int* lane) {
1056     Client client(session_, [this](const std::string& err) { OnError(err); });
1057 
1058     dap::StackFrame frame;
1059     if (!client.TopStackFrame(thread_id, &frame)) {
1060       return false;
1061     }
1062 
1063     dap::ScopesRequest scopeReq;
1064     dap::ScopesResponse scopeRes;
1065     scopeReq.frameId = frame.id;
1066     if (!client.Send(scopeReq, &scopeRes)) {
1067       return false;
1068     }
1069 
1070     Variables locals;
1071     if (!client.GetLocals(frame, &locals)) {
1072       return false;
1073     }
1074 
1075     for (int i = 0;; i++) {
1076       auto lane_var = client.GetLane(locals, i);
1077       if (!lane_var) {
1078         break;
1079       }
1080       if (auto var = lane_var->Find(name)) {
1081         T got;
1082         if (var->Get(&got)) {
1083           if (got == value) {
1084             *lane = i;
1085             return true;
1086           }
1087         }
1088       }
1089     }
1090 
1091     return false;
1092   }
1093 
1094   // FindGlobalInvocationId looks for the compute shader's global invocation id
1095   // in the stack frames' locals, returning true if found, and assigns the index
1096   // of the SIMD lane it was found in to |lane|.
1097   // TODO(bclayton): This value should probably be in the globals, not locals!
FindGlobalInvocationId(dap::integer thread_id,const GlobalInvocationId & id,int * lane)1098   bool FindGlobalInvocationId(dap::integer thread_id,
1099                               const GlobalInvocationId& id,
1100                               int* lane) {
1101     return FindLocal(thread_id, kGlobalInvocationId, id, lane);
1102   }
1103 
1104   // FindVertexIndex looks for the requested vertex shader's vertex index in the
1105   // stack frames' locals, returning true if found, and assigns the index of the
1106   // SIMD lane it was found in to |lane|.
1107   // TODO(bclayton): This value should probably be in the globals, not locals!
FindVertexIndex(dap::integer thread_id,uint32_t index,int * lane)1108   bool FindVertexIndex(dap::integer thread_id, uint32_t index, int* lane) {
1109     return FindLocal(thread_id, kVertexIndex, index, lane);
1110   }
1111 
1112   // FindWindowSpacePosition looks for the fragment shader's window space
1113   // position in the stack frames' locals, returning true if found, and assigns
1114   // the index of the SIMD lane it was found in to |lane|.
1115   // TODO(bclayton): This value should probably be in the globals, not locals!
FindWindowSpacePosition(dap::integer thread_id,const WindowSpacePosition & pos,int * lane)1116   bool FindWindowSpacePosition(dap::integer thread_id,
1117                                const WindowSpacePosition& pos,
1118                                int* lane) {
1119     return FindLocal(thread_id, kWindowSpacePosition, pos, lane);
1120   }
1121 
OnError(const std::string & error)1122   void OnError(const std::string& error) {
1123     DEBUGGER_LOG("ERROR: %s", error.c_str());
1124     error_ += error;
1125   }
1126 
1127   using PendingThreadsMap =
1128       std::unordered_map<InvocationKey,
1129                          std::shared_ptr<const debug::ThreadScript>,
1130                          InvocationKey::Hash>;
1131   using ThreadVector = std::vector<std::unique_ptr<Thread>>;
1132   using ThreadMap = std::unordered_map<int64_t, std::unique_ptr<Thread>>;
1133   VirtualFileStore* const virtual_files_;
1134   std::shared_ptr<dap::Session> session_;
1135   std::mutex threads_mutex_;
1136 
1137   // The threads that are waiting to be started.
1138   PendingThreadsMap pendingThreads_;  // guarded by threads_mutex_
1139 
1140   // All threads that have been started.
1141   ThreadMap runningThreads_;  // guarded by threads_mutex_
1142 
1143   // Threads that have finished.
1144   ThreadVector completedThreads_;  // guarded by threads_mutex_
1145   std::mutex error_mutex_;
1146   Result error_;  // guarded by error_mutex_
1147 };
1148 }  // namespace
1149 
GetDebugger(VirtualFileStore * virtual_files)1150 std::pair<Engine::Debugger*, Result> EngineVulkan::GetDebugger(
1151     VirtualFileStore* virtual_files) {
1152   if (!debugger_) {
1153     auto debugger = new VkDebugger(virtual_files);
1154     debugger_.reset(debugger);
1155     auto res = debugger->Connect();
1156     if (!res.IsSuccess()) {
1157       return {nullptr, res};
1158     }
1159   }
1160   return {debugger_.get(), Result()};
1161 }
1162 
1163 }  // namespace vulkan
1164 }  // namespace amber
1165 
1166 #else  // AMBER_ENABLE_VK_DEBUGGING
1167 
1168 namespace amber {
1169 namespace vulkan {
1170 
GetDebugger(VirtualFileStore *)1171 std::pair<Engine::Debugger*, Result> EngineVulkan::GetDebugger(
1172     VirtualFileStore*) {
1173   return {nullptr,
1174           Result("Amber was not built with AMBER_ENABLE_VK_DEBUGGING enabled")};
1175 }
1176 
1177 }  // namespace vulkan
1178 }  // namespace amber
1179 
1180 #endif  // AMBER_ENABLE_VK_DEBUGGING
1181