1 // Copyright 2020 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/vulkan/engine_vulkan.h"
16
17 #if AMBER_ENABLE_VK_DEBUGGING
18
19 #include <stdlib.h>
20
21 #include <chrono> // NOLINT(build/c++11)
22 #include <condition_variable> // NOLINT(build/c++11)
23 #include <fstream>
24 #include <mutex> // NOLINT(build/c++11)
25 #include <sstream>
26 #include <thread> // NOLINT(build/c++11)
27 #include <unordered_map>
28
29 #include "dap/network.h"
30 #include "dap/protocol.h"
31 #include "dap/session.h"
32 #include "src/virtual_file_store.h"
33
// Set to 1 to enable verbose debugger logging.
// When enabled, DEBUGGER_LOG() forwards its printf-style arguments to printf()
// and then prints a newline.
// NOTE(review): <cstdio>/<stdio.h> is not among the includes visible in this
// file; enabling logging presumably relies on a transitive include — confirm
// before flipping the switch.
#define ENABLE_DEBUGGER_LOG 0

#if ENABLE_DEBUGGER_LOG
// Wrapped in do { } while (false) so that the macro behaves as a single
// statement (safe in un-braced if/else bodies).
#define DEBUGGER_LOG(...) \
  do {                    \
    printf(__VA_ARGS__);  \
    printf("\n");         \
  } while (false)
#else
// Logging disabled: the macro expands to nothing, so its arguments are never
// evaluated.
#define DEBUGGER_LOG(...)
#endif
46
47 namespace amber {
48 namespace vulkan {
49
50 namespace {
51
// kThreadTimeout bounds how long Thread::Flush() waits for a debugger thread
// script to finish before reporting a timeout.
constexpr auto kThreadTimeout = std::chrono::minutes(1);

// kStepEventTimeout bounds how long a thread script waits for the debugger to
// report that a single step (over / in / out) request has completed.
constexpr auto kStepEventTimeout = std::chrono::seconds(1);
58
// Event is a one-shot, latching wait/signal primitive: once Signal() has been
// called, every current and future Wait() succeeds. The signalled state is
// never cleared.
class Event {
 public:
  // Wait blocks the caller until Signal() has been called or |duration| has
  // elapsed, whichever comes first. Returns true if the Event is signalled,
  // false if the wait timed out.
  template <typename Rep, typename Period>
  bool Wait(const std::chrono::duration<Rep, Period>& duration) {
    std::unique_lock<std::mutex> guard(mutex_);
    auto is_signalled = [this] { return signalled_; };
    return cv_.wait_for(guard, duration, is_signalled);
  }

  // Signal marks the Event as signalled and wakes every blocked Wait() call.
  void Signal() {
    std::lock_guard<std::mutex> guard(mutex_);
    signalled_ = true;
    cv_.notify_all();
  }

 private:
  std::condition_variable cv_;
  std::mutex mutex_;
  bool signalled_ = false;
};
82
// Split slices |str| into all substrings separated by |sep| and returns a
// vector of the substrings between those separators. Leading, trailing or
// adjacent separators produce empty substrings, so the result always holds one
// more element than there are (non-overlapping) occurrences of |sep|.
// An empty |sep| cannot delimit anything; in that case |str| is returned as
// the only element.
std::vector<std::string> Split(const std::string& str, const std::string& sep) {
  std::vector<std::string> out;
  if (sep.empty()) {
    out.push_back(str);
    return out;
  }
  std::size_t prev = 0;
  std::size_t cur = 0;
  while ((cur = str.find(sep, prev)) != std::string::npos) {
    out.push_back(str.substr(prev, cur - prev));
    // Skip the whole separator, not just its first character.
    prev = cur + sep.size();
  }
  out.push_back(str.substr(prev));
  return out;
}
96
// GlobalInvocationId holds a three-element unsigned integer index, used to
// identify a single compute invocation.
struct GlobalInvocationId {
  // hash packs the three components into one value. Components are widened to
  // size_t before shifting so that, where size_t is wider than 32 bits, the
  // high bits of x / y are not discarded by a 32-bit shift.
  size_t hash() const {
    return static_cast<size_t>(x) << 20 | static_cast<size_t>(y) << 10 | z;
  }
  bool operator==(const GlobalInvocationId& other) const {
    return x == other.x && y == other.y && z == other.z;
  }

  uint32_t x;
  uint32_t y;
  uint32_t z;
};

// WindowSpacePosition holds a two-element unsigned integer index, used to
// identify a single fragment invocation.
struct WindowSpacePosition {
  // hash packs the two components into one value (x widened before shifting).
  size_t hash() const { return static_cast<size_t>(x) << 10 | y; }
  bool operator==(const WindowSpacePosition& other) const {
    return x == other.x && y == other.y;
  }

  uint32_t x;
  uint32_t y;
};

// Forward declaration.
struct Variable;

// Variables is a list of Variable, with lookup helper methods.
class Variables : public std::vector<Variable> {
 public:
  // Find returns the first variable named |name|, or nullptr if none exists.
  inline const Variable* Find(const std::string& name) const;
  // AllNames returns every variable name, quoted and comma-separated.
  inline std::string AllNames() const;
};

// Variable holds a debugger returned named value (local, global, etc).
// Variables can hold child variables (for structs, arrays, etc).
struct Variable {
  std::string name;
  std::string value;
  Variables children;

  // Get parses the Variable value for the requested type, assigning the result
  // to |out|. Returns true on success, otherwise false.
  bool Get(uint32_t* out) const {
    // strtoul instead of atoi: atoi's behavior is undefined for values that do
    // not fit in an int (e.g. uint32 values above INT_MAX).
    *out = static_cast<uint32_t>(std::strtoul(value.c_str(), nullptr, 10));
    return true;  // TODO(bclayton): Verify the value parsed correctly.
  }

  bool Get(int64_t* out) const {
    // strtoll parses the full 64-bit range; atoi would truncate to int.
    *out = static_cast<int64_t>(std::strtoll(value.c_str(), nullptr, 10));
    return true;  // TODO(bclayton): Verify the value parsed correctly.
  }

  bool Get(double* out) const {
    *out = std::atof(value.c_str());
    return true;  // TODO(bclayton): Verify the value parsed correctly.
  }

  bool Get(std::string* out) const {
    *out = value;
    return true;
  }

  // Reads the "x", "y" and "z" children. Fails if any is missing.
  bool Get(GlobalInvocationId* out) const {
    auto x = children.Find("x");
    auto y = children.Find("y");
    auto z = children.Find("z");
    return (x != nullptr && y != nullptr && z != nullptr && x->Get(&out->x) &&
            y->Get(&out->y) && z->Get(&out->z));
  }

  // Reads the "x" and "y" children. Fails if either is missing.
  bool Get(WindowSpacePosition* out) const {
    auto x = children.Find("x");
    auto y = children.Find("y");
    return (x != nullptr && y != nullptr && x->Get(&out->x) && y->Get(&out->y));
  }
};

const Variable* Variables::Find(const std::string& name) const {
  for (auto& child : *this) {
    if (child.name == name) {
      return &child;
    }
  }
  return nullptr;
}

std::string Variables::AllNames() const {
  std::string out;
  for (auto& var : *this) {
    if (out.size() > 0) {
      out += ", ";
    }
    out += "'" + var.name + "'";
  }
  return out;
}
195
196 // Client wraps a dap::Session and a error handler, and provides a more
197 // convenient interface for talking to the debugger. Client also provides basic
198 // immutable data caching to help performance.
199 class Client {
200 static constexpr const char* kLocals = "locals";
201 static constexpr const char* kLane = "Lane";
202
203 public:
204 using ErrorHandler = std::function<void(const std::string&)>;
205 using SourceLines = std::vector<std::string>;
206
Client(const std::shared_ptr<dap::Session> & session,const ErrorHandler & onerror)207 Client(const std::shared_ptr<dap::Session>& session,
208 const ErrorHandler& onerror)
209 : session_(session), onerror_(onerror) {}
210
211 // TopStackFrame retrieves the frame at the top of the thread's call stack.
212 // Returns true on success, false on error.
TopStackFrame(dap::integer thread_id,dap::StackFrame * frame)213 bool TopStackFrame(dap::integer thread_id, dap::StackFrame* frame) {
214 std::vector<dap::StackFrame> stack;
215 if (!Callstack(thread_id, &stack)) {
216 return false;
217 }
218 *frame = stack.front();
219 return true;
220 }
221
222 // Callstack retrieves the thread's full call stack.
223 // Returns true on success, false on error.
Callstack(dap::integer thread_id,std::vector<dap::StackFrame> * stack)224 bool Callstack(dap::integer thread_id, std::vector<dap::StackFrame>* stack) {
225 dap::StackTraceRequest request;
226 request.threadId = thread_id;
227 auto response = session_->send(request).get();
228 if (response.error) {
229 onerror_(response.error.message);
230 return false;
231 }
232 if (response.response.stackFrames.size() == 0) {
233 onerror_("Stack frame is empty");
234 return false;
235 }
236 *stack = response.response.stackFrames;
237 return true;
238 }
239
240 // FrameLocation retrieves the current frame source location, and optional
241 // source line text.
242 // Returns true on success, false on error.
FrameLocation(VirtualFileStore * virtual_files,const dap::StackFrame & frame,debug::Location * location,std::string * line=nullptr)243 bool FrameLocation(VirtualFileStore* virtual_files,
244 const dap::StackFrame& frame,
245 debug::Location* location,
246 std::string* line = nullptr) {
247 location->line = static_cast<uint32_t>(frame.line);
248
249 if (!frame.source.has_value()) {
250 onerror_("Stack frame with name '" + frame.name + "' has no source");
251 return false;
252 } else if (frame.source->path.has_value()) {
253 location->file = frame.source.value().path.value();
254 } else if (frame.source->name.has_value()) {
255 location->file = frame.source.value().name.value();
256 } else {
257 onerror_("Frame source had no path or name");
258 return false;
259 }
260
261 if (location->line < 1) {
262 onerror_("Line location is " + std::to_string(location->line));
263 return false;
264 }
265
266 if (line != nullptr) {
267 SourceLines lines;
268 if (!SourceContent(virtual_files, frame.source.value(), &lines)) {
269 return false;
270 }
271 if (location->line > lines.size()) {
272 onerror_("Line " + std::to_string(location->line) +
273 " is greater than the number of lines in the source file (" +
274 std::to_string(lines.size()) + ")");
275 return false;
276 }
277 if (location->line < 1) {
278 onerror_("Line number reported was " + std::to_string(location->line));
279 return false;
280 }
281 *line = lines[location->line - 1];
282 }
283
284 return true;
285 }
286
287 // SourceContext retrieves the the SourceLines for the given source.
288 // Returns true on success, false on error.
SourceContent(VirtualFileStore * virtual_files,const dap::Source & source,SourceLines * out)289 bool SourceContent(VirtualFileStore* virtual_files,
290 const dap::Source& source,
291 SourceLines* out) {
292 auto path = source.path.value("");
293
294 if (source.sourceReference.has_value()) {
295 auto ref = source.sourceReference.value();
296 auto it = sourceCache_.by_ref.find(ref);
297 if (it != sourceCache_.by_ref.end()) {
298 *out = it->second;
299 return true;
300 }
301
302 dap::SourceRequest request;
303 dap::SourceResponse response;
304 request.sourceReference = ref;
305 if (!Send(request, &response)) {
306 return false;
307 }
308 auto lines = Split(response.content, "\n");
309 sourceCache_.by_ref.emplace(ref, lines);
310 *out = lines;
311 return true;
312 }
313
314 if (path != "") {
315 auto it = sourceCache_.by_path.find(path);
316 if (it != sourceCache_.by_path.end()) {
317 *out = it->second;
318 return true;
319 }
320
321 // First search the virtual file store.
322 std::string content;
323 if (virtual_files->Get(path, &content).IsSuccess()) {
324 auto lines = Split(content, "\n");
325 sourceCache_.by_path.emplace(path, lines);
326 *out = lines;
327 return true;
328 }
329
330 // TODO(bclayton) - We shouldn't be doing direct file IO here. We should
331 // bubble the IO request to the amber 'embedder'.
332 // See: https://github.com/google/amber/issues/777
333 if (auto file = std::ifstream(path)) {
334 SourceLines lines;
335 std::string line;
336 while (std::getline(file, line)) {
337 lines.emplace_back(line);
338 }
339
340 sourceCache_.by_path.emplace(path, lines);
341 *out = lines;
342 return true;
343 }
344 }
345
346 onerror_("Could not get source content");
347 return false;
348 }
349
350 // Send sends the request to the debugger, waits for the request to complete,
351 // and then assigns the response to |res|.
352 // Returns true on success, false on error.
353 template <typename REQUEST, typename RESPONSE>
Send(const REQUEST & request,RESPONSE * res)354 bool Send(const REQUEST& request, RESPONSE* res) {
355 auto r = session_->send(request).get();
356 if (r.error) {
357 onerror_(r.error.message);
358 return false;
359 }
360 *res = r.response;
361 return true;
362 }
363
364 // Send sends the request to the debugger, and waits for the request to
365 // complete.
366 // Returns true on success, false on error.
367 template <typename REQUEST>
Send(const REQUEST & request)368 bool Send(const REQUEST& request) {
369 using RESPONSE = typename REQUEST::Response;
370 RESPONSE response;
371 return Send(request, &response);
372 }
373
374 // GetVariables fetches the fully traversed set of Variables from the debugger
375 // for the given reference identifier.
376 // Returns true on success, false on error.
GetVariables(dap::integer variablesRef,Variables * out)377 bool GetVariables(dap::integer variablesRef, Variables* out) {
378 dap::VariablesRequest request;
379 dap::VariablesResponse response;
380 request.variablesReference = variablesRef;
381 if (!Send(request, &response)) {
382 return false;
383 }
384 for (auto var : response.variables) {
385 Variable v;
386 v.name = var.name;
387 v.value = var.value;
388 if (var.variablesReference > 0) {
389 if (!GetVariables(var.variablesReference, &v.children)) {
390 return false;
391 }
392 }
393 out->emplace_back(v);
394 }
395 return true;
396 }
397
398 // GetLocals fetches the fully traversed set of local Variables from the
399 // debugger for the given stack frame.
400 // Returns true on success, false on error.
GetLocals(const dap::StackFrame & frame,Variables * out)401 bool GetLocals(const dap::StackFrame& frame, Variables* out) {
402 dap::ScopesRequest scopeReq;
403 dap::ScopesResponse scopeRes;
404 scopeReq.frameId = frame.id;
405 if (!Send(scopeReq, &scopeRes)) {
406 return false;
407 }
408
409 for (auto scope : scopeRes.scopes) {
410 if (scope.presentationHint.value("") == kLocals) {
411 return GetVariables(scope.variablesReference, out);
412 }
413 }
414
415 onerror_("Locals scope not found");
416 return false;
417 }
418
419 // GetLane returns a pointer to the Variables representing the thread's SIMD
420 // lane with the given index, or nullptr if the lane was not found.
GetLane(const Variables & lanes,int lane)421 const Variables* GetLane(const Variables& lanes, int lane) {
422 auto out = lanes.Find(std::string(kLane) + " " + std::to_string(lane));
423 if (out == nullptr) {
424 return nullptr;
425 }
426 return &out->children;
427 }
428
429 private:
430 struct SourceCache {
431 std::unordered_map<int64_t, SourceLines> by_ref;
432 std::unordered_map<std::string, SourceLines> by_path;
433 };
434
435 std::shared_ptr<dap::Session> session_;
436 ErrorHandler onerror_;
437 SourceCache sourceCache_;
438 };
439
440 // InvocationKey is a tagged-union structure that identifies a single shader
441 // invocation.
442 struct InvocationKey {
443 // Hash is a custom hasher that can enable InvocationKeys to be used as keys
444 // in std containers.
445 struct Hash {
446 size_t operator()(const InvocationKey& key) const;
447 };
448
449 enum class Type { kGlobalInvocationId, kVertexIndex, kWindowSpacePosition };
450 union Data {
451 GlobalInvocationId global_invocation_id;
452 uint32_t vertex_id;
453 WindowSpacePosition window_space_position;
454 };
455
456 explicit InvocationKey(const GlobalInvocationId&);
457 explicit InvocationKey(const WindowSpacePosition&);
458 InvocationKey(Type, const Data&);
459
460 bool operator==(const InvocationKey& other) const;
461
462 // String returns a human-readable description of the key.
463 std::string String() const;
464
465 Type type;
466 Data data;
467 };
468
operator ()(const InvocationKey & key) const469 size_t InvocationKey::Hash::operator()(const InvocationKey& key) const {
470 size_t hash = 31 * static_cast<size_t>(key.type);
471 switch (key.type) {
472 case Type::kGlobalInvocationId:
473 hash += key.data.global_invocation_id.hash();
474 break;
475 case Type::kVertexIndex:
476 hash += key.data.vertex_id;
477 break;
478 case Type::kWindowSpacePosition:
479 hash += key.data.window_space_position.hash();
480 break;
481 }
482 return hash;
483 }
484
InvocationKey(const GlobalInvocationId & id)485 InvocationKey::InvocationKey(const GlobalInvocationId& id)
486 : type(Type::kGlobalInvocationId) {
487 data.global_invocation_id = id;
488 }
489
InvocationKey(const WindowSpacePosition & pos)490 InvocationKey::InvocationKey(const WindowSpacePosition& pos)
491 : type(Type::kWindowSpacePosition) {
492 data.window_space_position = pos;
493 }
494
InvocationKey(Type type_,const Data & data_)495 InvocationKey::InvocationKey(Type type_, const Data& data_)
496 : type(type_), data(data_) {}
497
String() const498 std::string InvocationKey::String() const {
499 std::stringstream ss;
500 switch (type) {
501 case Type::kGlobalInvocationId:
502 ss << "GlobalInvocation(" << data.global_invocation_id.x << ", "
503 << data.global_invocation_id.y << ", " << data.global_invocation_id.z
504 << ")";
505 break;
506 case Type::kVertexIndex:
507 ss << "VertexIndex(" << data.vertex_id << ")";
508 break;
509 case Type::kWindowSpacePosition:
510 ss << "WindowSpacePosition(" << data.window_space_position.x << ", "
511 << data.window_space_position.y << ")";
512 break;
513 }
514 return ss.str();
515 }
516
operator ==(const InvocationKey & other) const517 bool InvocationKey::operator==(const InvocationKey& other) const {
518 if (type != other.type) {
519 return false;
520 }
521 switch (type) {
522 case Type::kGlobalInvocationId:
523 return data.global_invocation_id == other.data.global_invocation_id;
524 case Type::kVertexIndex:
525 return data.vertex_id == other.data.vertex_id;
526 case Type::kWindowSpacePosition:
527 return data.window_space_position == other.data.window_space_position;
528 }
529 return false;
530 }
531
532 // Thread controls and verifies a single debugger thread of execution.
533 class Thread : public debug::Thread {
534 public:
Thread(VirtualFileStore * virtual_files,std::shared_ptr<dap::Session> session,dap::integer threadId,int lane,std::shared_ptr<const debug::ThreadScript> script)535 Thread(VirtualFileStore* virtual_files,
536 std::shared_ptr<dap::Session> session,
537 dap::integer threadId,
538 int lane,
539 std::shared_ptr<const debug::ThreadScript> script)
540 : virtual_files_(virtual_files),
541 thread_id_(threadId),
542 lane_(lane),
543 client_(session, [this](const std::string& err) { OnError(err); }) {
544 // The thread script runs concurrently with other debugger thread scripts.
545 // Run on a separate amber thread.
__anonf24a5b550402null546 thread_ = std::thread([this, script] {
547 script->Run(this); // Begin running the thread script.
548 done_.Signal(); // Signal when done.
549 });
550 }
551
~Thread()552 ~Thread() override { Flush(); }
553
554 // Flush waits for the debugger thread script to complete, and returns any
555 // errors encountered.
Flush()556 Result Flush() {
557 if (done_.Wait(kThreadTimeout)) {
558 if (thread_.joinable()) {
559 thread_.join();
560 }
561 } else {
562 error_ += "Timed out performing actions";
563 }
564 return error_;
565 }
566
567 // OnStep is called when the debugger sends a 'stopped' event with reason
568 // 'step'.
OnStep()569 void OnStep() { on_stepped_.Signal(); }
570
571 // WaitForSteppedEvent halts execution of the thread until OnStep() is
572 // called, or kStepEventTimeout is reached.
WaitForStepEvent(const std::string & request)573 void WaitForStepEvent(const std::string& request) {
574 if (!on_stepped_.Wait(kStepEventTimeout)) {
575 OnError("Timeout waiting for " + request + " request to complete");
576 }
577 }
578
579 // debug::Thread compliance
StepOver()580 void StepOver() override {
581 DEBUGGER_LOG("StepOver()");
582 dap::NextRequest request;
583 request.threadId = thread_id_;
584 client_.Send(request);
585 WaitForStepEvent("step-over");
586 }
587
StepIn()588 void StepIn() override {
589 DEBUGGER_LOG("StepIn()");
590 dap::StepInRequest request;
591 request.threadId = thread_id_;
592 client_.Send(request);
593 WaitForStepEvent("step-in");
594 }
595
StepOut()596 void StepOut() override {
597 DEBUGGER_LOG("StepOut()");
598 dap::StepOutRequest request;
599 request.threadId = thread_id_;
600 client_.Send(request);
601 WaitForStepEvent("step-out");
602 }
603
Continue()604 void Continue() override {
605 DEBUGGER_LOG("Continue()");
606 dap::ContinueRequest request;
607 request.threadId = thread_id_;
608 client_.Send(request);
609 }
610
ExpectLocation(const debug::Location & location,const std::string & line)611 void ExpectLocation(const debug::Location& location,
612 const std::string& line) override {
613 DEBUGGER_LOG("ExpectLocation('%s', %d)", location.file.c_str(),
614 location.line);
615
616 dap::StackFrame frame;
617 if (!client_.TopStackFrame(thread_id_, &frame)) {
618 return;
619 }
620
621 debug::Location got_location;
622 std::string got_source_line;
623 if (!client_.FrameLocation(virtual_files_, frame, &got_location,
624 &got_source_line)) {
625 return;
626 }
627
628 if (got_location.file != location.file) {
629 OnError("Expected file to be '" + location.file + "' but file was '" +
630 got_location.file + "'");
631 } else if (got_location.line != location.line) {
632 std::stringstream ss;
633 ss << "Expected line " << std::to_string(location.line);
634 if (line != "") {
635 ss << " `" << line << "`";
636 }
637 ss << " but line was " << std::to_string(got_location.line) << " `"
638 << got_source_line << "`";
639 OnError(ss.str());
640 } else if (line != "" && got_source_line != line) {
641 OnError("Expected source for line " + std::to_string(location.line) +
642 " to be: `" + line + "` but line was: `" + got_source_line + "`");
643 }
644 }
645
ExpectCallstack(const std::vector<debug::StackFrame> & callstack)646 void ExpectCallstack(
647 const std::vector<debug::StackFrame>& callstack) override {
648 DEBUGGER_LOG("ExpectCallstack()");
649
650 std::vector<dap::StackFrame> got_stack;
651 if (!client_.Callstack(thread_id_, &got_stack)) {
652 return;
653 }
654
655 std::stringstream ss;
656
657 size_t count = std::min(callstack.size(), got_stack.size());
658 for (size_t i = 0; i < count; i++) {
659 auto const& got_frame = got_stack[i];
660 auto const& want_frame = callstack[i];
661 bool ok = got_frame.name == want_frame.name;
662 if (ok && want_frame.location.file != "") {
663 ok = got_frame.source.has_value() &&
664 got_frame.source->name.value("") == want_frame.location.file;
665 }
666 if (ok && want_frame.location.line != 0) {
667 ok = got_frame.line == static_cast<int>(want_frame.location.line);
668 }
669 if (!ok) {
670 ss << "Unexpected stackframe at frame " << i
671 << "\nGot: " << FrameString(got_frame)
672 << "\nExpected: " << FrameString(want_frame) << "\n";
673 }
674 }
675
676 if (got_stack.size() > callstack.size()) {
677 ss << "Callstack has an additional "
678 << (got_stack.size() - callstack.size()) << " unexpected frames\n";
679 } else if (callstack.size() > got_stack.size()) {
680 ss << "Callstack is missing " << (callstack.size() - got_stack.size())
681 << " frames\n";
682 }
683
684 if (ss.str().size() > 0) {
685 ss << "Full callstack:\n";
686 for (auto& frame : got_stack) {
687 ss << " " << FrameString(frame) << "\n";
688 }
689 OnError(ss.str());
690 }
691 }
692
ExpectLocal(const std::string & name,int64_t value)693 void ExpectLocal(const std::string& name, int64_t value) override {
694 DEBUGGER_LOG("ExpectLocal('%s', %d)", name.c_str(), (int)value);
695 ExpectLocalT(name, value, [](int64_t a, int64_t b) { return a == b; });
696 }
697
ExpectLocal(const std::string & name,double value)698 void ExpectLocal(const std::string& name, double value) override {
699 DEBUGGER_LOG("ExpectLocal('%s', %f)", name.c_str(), value);
700 ExpectLocalT(name, value,
701 [](double a, double b) { return std::abs(a - b) < 0.00001; });
702 }
703
ExpectLocal(const std::string & name,const std::string & value)704 void ExpectLocal(const std::string& name, const std::string& value) override {
705 DEBUGGER_LOG("ExpectLocal('%s', '%s')", name.c_str(), value.c_str());
706 ExpectLocalT(name, value, [](const std::string& a, const std::string& b) {
707 return a == b;
708 });
709 }
710
711 // ExpectLocalT verifies that the local variable with the given name has the
712 // expected value. |name| may contain `.` delimiters to index structure or
713 // array types. ExpectLocalT uses the function equal to compare the expected
714 // and actual values, returning true if considered equal, otherwise false.
715 // EQUAL_FUNC is a function-like type with the signature:
716 // bool(T a, T b)
717 template <typename T, typename EQUAL_FUNC>
ExpectLocalT(const std::string & name,const T & expect,EQUAL_FUNC && equal)718 void ExpectLocalT(const std::string& name,
719 const T& expect,
720 EQUAL_FUNC&& equal) {
721 dap::StackFrame frame;
722 if (!client_.TopStackFrame(thread_id_, &frame)) {
723 return;
724 }
725
726 Variables locals;
727 if (!client_.GetLocals(frame, &locals)) {
728 return;
729 }
730
731 if (auto lane = client_.GetLane(locals, lane_)) {
732 auto owner = lane;
733 const Variable* var = nullptr;
734 std::string path;
735 for (auto part : Split(name, ".")) {
736 var = owner->Find(part);
737 if (!var) {
738 if (owner == lane) {
739 OnError("Local '" + name + "' not found.\nLocals: " +
740 lane->AllNames() + ".\nLanes: " + locals.AllNames() + ".");
741 } else {
742 OnError("Local '" + path + "' does not contain '" + part +
743 "'\nChildren: " + owner->AllNames());
744 }
745 return;
746 }
747 owner = &var->children;
748 path += (path.size() > 0) ? "." + part : part;
749 }
750
751 T got = {};
752 if (!var->Get(&got)) {
753 OnError("Local '" + name + "' was not of expected type");
754 return;
755 }
756
757 if (!equal(got, expect)) {
758 std::stringstream ss;
759 ss << "Local '" << name << "' did not have expected value. Value is '"
760 << got << "', expected '" << expect << "'";
761 OnError(ss.str());
762 return;
763 }
764 }
765 }
766
767 private:
OnError(const std::string & err)768 void OnError(const std::string& err) {
769 DEBUGGER_LOG("ERROR: %s", err.c_str());
770 error_ += err;
771 }
772
FrameString(const dap::StackFrame & frame)773 std::string FrameString(const dap::StackFrame& frame) {
774 std::stringstream ss;
775 ss << frame.name;
776 if (frame.source.has_value() && frame.source->name.has_value()) {
777 ss << " " << frame.source->name.value() << ":" << frame.line;
778 }
779 return ss.str();
780 }
781
FrameString(const debug::StackFrame & frame)782 std::string FrameString(const debug::StackFrame& frame) {
783 std::stringstream ss;
784 ss << frame.name;
785 if (frame.location.file != "") {
786 ss << " " << frame.location.file;
787 if (frame.location.line != 0) {
788 ss << ":" << frame.location.line;
789 }
790 }
791 return ss.str();
792 }
793
794 VirtualFileStore* const virtual_files_;
795 const dap::integer thread_id_;
796 const int lane_;
797 Client client_;
798 std::thread thread_;
799 Event done_;
800 Event on_stepped_;
801 Result error_;
802 };
803
804 // VkDebugger is a private implementation of the Engine::Debugger interface.
805 class VkDebugger : public Engine::Debugger {
806 static constexpr const char* kComputeShaderFunctionName = "ComputeShader";
807 static constexpr const char* kVertexShaderFunctionName = "VertexShader";
808 static constexpr const char* kFragmentShaderFunctionName = "FragmentShader";
809 static constexpr const char* kGlobalInvocationId = "globalInvocationId";
810 static constexpr const char* kWindowSpacePosition = "windowSpacePosition";
811 static constexpr const char* kVertexIndex = "vertexIndex";
812
813 public:
VkDebugger(VirtualFileStore * virtual_files)814 explicit VkDebugger(VirtualFileStore* virtual_files)
815 : virtual_files_(virtual_files) {}
816
817 /// Connect establishes the connection to the shader debugger. Must be
818 /// called before any of the |debug::Events| methods.
Connect()819 Result Connect() {
820 auto port_str = getenv("VK_DEBUGGER_PORT");
821 if (!port_str) {
822 return Result("The environment variable VK_DEBUGGER_PORT was not set\n");
823 }
824 int port = atoi(port_str);
825 if (port == 0) {
826 return Result("Could not parse port number from VK_DEBUGGER_PORT ('" +
827 std::string(port_str) + "')");
828 }
829
830 constexpr int kMaxAttempts = 10;
831 // The socket might take a while to open - retry connecting.
832 for (int attempt = 0; attempt < kMaxAttempts; attempt++) {
833 auto connection = dap::net::connect("localhost", port);
834 if (!connection) {
835 std::this_thread::sleep_for(std::chrono::seconds(1));
836 continue;
837 }
838
839 // Socket opened. Create the debugger session and bind.
840 session_ = dap::Session::create();
841 session_->bind(connection);
842
843 // Register the thread stopped event.
844 // This is fired when breakpoints are hit (amongst other reasons).
845 // See:
846 // https://microsoft.github.io/debug-adapter-protocol/specification#Events_Stopped
847 session_->registerHandler([&](const dap::StoppedEvent& event) {
848 DEBUGGER_LOG("THREAD STOPPED. Reason: %s", event.reason.c_str());
849 if (event.reason == "function breakpoint" ||
850 event.reason == "breakpoint") {
851 OnBreakpointHit(event.threadId.value(0));
852 } else if (event.reason == "step") {
853 OnStep(event.threadId.value(0));
854 }
855 });
856
857 // Start the debugger initialization sequence.
858 // See: https://microsoft.github.io/debug-adapter-protocol/overview for
859 // details.
860
861 dap::InitializeRequest init_req = {};
862 auto init_res = session_->send(init_req).get();
863 if (init_res.error) {
864 DEBUGGER_LOG("InitializeRequest failed: %s",
865 init_res.error.message.c_str());
866 return Result(init_res.error.message);
867 }
868
869 // Set breakpoints on the various shader types, we do this even if we
870 // don't actually care about these threads. Once the breakpoint is hit,
871 // the pendingThreads_ map is probed, if nothing matches the thread is
872 // resumed.
873 // TODO(bclayton): Once we have conditional breakpoint support, we can
874 // reduce the number of breakpoints / scope of breakpoints.
875 dap::SetFunctionBreakpointsRequest fbp_req = {};
876 dap::FunctionBreakpoint fbp = {};
877 fbp.name = kComputeShaderFunctionName;
878 fbp_req.breakpoints.emplace_back(fbp);
879 fbp.name = kVertexShaderFunctionName;
880 fbp_req.breakpoints.emplace_back(fbp);
881 fbp.name = kFragmentShaderFunctionName;
882 fbp_req.breakpoints.emplace_back(fbp);
883 auto fbp_res = session_->send(fbp_req).get();
884 if (fbp_res.error) {
885 DEBUGGER_LOG("SetFunctionBreakpointsRequest failed: %s",
886 fbp_res.error.message.c_str());
887 return Result(fbp_res.error.message);
888 }
889
890 // ConfigurationDone signals the initialization has completed.
891 dap::ConfigurationDoneRequest cfg_req = {};
892 auto cfg_res = session_->send(cfg_req).get();
893 if (cfg_res.error) {
894 DEBUGGER_LOG("ConfigurationDoneRequest failed: %s",
895 cfg_res.error.message.c_str());
896 return Result(cfg_res.error.message);
897 }
898
899 return Result();
900 }
901 return Result("Unable to connect to debugger on port " +
902 std::to_string(port));
903 }
904
905 // Flush checks that all breakpoints were hit, waits for all threads to
906 // complete, and returns the globbed together results for all threads.
Flush()907 Result Flush() override {
908 Result result;
909 {
910 std::unique_lock<std::mutex> lock(error_mutex_);
911 result += error_;
912 }
913 {
914 std::unique_lock<std::mutex> lock(threads_mutex_);
915 for (auto& pending : pendingThreads_) {
916 result += "Thread did not run: " + pending.first.String();
917 }
918 for (auto& it : runningThreads_) {
919 result += it.second->Flush();
920 }
921 runningThreads_.clear();
922 completedThreads_.clear();
923 }
924 return result;
925 }
926
927 // debug::Events compliance
BreakOnComputeGlobalInvocation(uint32_t x,uint32_t y,uint32_t z,const std::shared_ptr<const debug::ThreadScript> & script)928 void BreakOnComputeGlobalInvocation(
929 uint32_t x,
930 uint32_t y,
931 uint32_t z,
932 const std::shared_ptr<const debug::ThreadScript>& script) override {
933 std::unique_lock<std::mutex> lock(threads_mutex_);
934 pendingThreads_.emplace(GlobalInvocationId{x, y, z}, script);
935 }
936
BreakOnVertexIndex(uint32_t index,const std::shared_ptr<const debug::ThreadScript> & script)937 void BreakOnVertexIndex(
938 uint32_t index,
939 const std::shared_ptr<const debug::ThreadScript>& script) override {
940 InvocationKey::Data data;
941 data.vertex_id = index;
942 auto key = InvocationKey{InvocationKey::Type::kVertexIndex, data};
943 std::unique_lock<std::mutex> lock(threads_mutex_);
944 pendingThreads_.emplace(key, script);
945 }
946
BreakOnFragmentWindowSpacePosition(uint32_t x,uint32_t y,const std::shared_ptr<const debug::ThreadScript> & script)947 void BreakOnFragmentWindowSpacePosition(
948 uint32_t x,
949 uint32_t y,
950 const std::shared_ptr<const debug::ThreadScript>& script) override {
951 std::unique_lock<std::mutex> lock(threads_mutex_);
952 pendingThreads_.emplace(WindowSpacePosition{x, y}, script);
953 }
954
955 private:
StartThread(dap::integer thread_id,int lane,const std::shared_ptr<const amber::debug::ThreadScript> & script)956 void StartThread(
957 dap::integer thread_id,
958 int lane,
959 const std::shared_ptr<const amber::debug::ThreadScript>& script) {
960 auto thread =
961 MakeUnique<Thread>(virtual_files_, session_, thread_id, lane, script);
962 auto it = runningThreads_.find(thread_id);
963 if (it != runningThreads_.end()) {
964 // The debugger has now started working on something else for the given
965 // thread id.
966 // Flush the Thread to ensure that all tasks have actually finished, and
967 // move the Thread to the completedThreads_ list.
968 auto completed_thread = std::move(it->second);
969 error_ += completed_thread->Flush();
970 completedThreads_.emplace_back(std::move(completed_thread));
971 }
972 runningThreads_[thread_id] = std::move(thread);
973 }
974
975 // OnBreakpointHit is called when a debugger breakpoint is hit (breakpoints
976 // are set at shader entry points). pendingThreads_ is checked to see if this
977 // thread needs testing, and if so, creates a new ::Thread.
978 // If there's no pendingThread_ entry for the given thread, it is resumed to
979 // allow the shader to continue executing.
OnBreakpointHit(dap::integer thread_id)980 void OnBreakpointHit(dap::integer thread_id) {
981 DEBUGGER_LOG("Breakpoint hit: thread %d", (int)thread_id);
982 Client client(session_, [this](const std::string& err) { OnError(err); });
983
984 std::unique_lock<std::mutex> lock(threads_mutex_);
985 for (auto it = pendingThreads_.begin(); it != pendingThreads_.end(); it++) {
986 auto& key = it->first;
987 auto& script = it->second;
988 switch (key.type) {
989 case InvocationKey::Type::kGlobalInvocationId: {
990 auto invocation_id = key.data.global_invocation_id;
991 int lane;
992 if (FindGlobalInvocationId(thread_id, invocation_id, &lane)) {
993 DEBUGGER_LOG("Breakpoint hit: GetGlobalInvocationId: [%d, %d, %d]",
994 invocation_id.x, invocation_id.y, invocation_id.z);
995 StartThread(thread_id, lane, script);
996 pendingThreads_.erase(it);
997 return;
998 }
999 break;
1000 }
1001 case InvocationKey::Type::kVertexIndex: {
1002 auto vertex_id = key.data.vertex_id;
1003 int lane;
1004 if (FindVertexIndex(thread_id, vertex_id, &lane)) {
1005 DEBUGGER_LOG("Breakpoint hit: VertexId: %d", vertex_id);
1006 StartThread(thread_id, lane, script);
1007 pendingThreads_.erase(it);
1008 return;
1009 }
1010 break;
1011 }
1012 case InvocationKey::Type::kWindowSpacePosition: {
1013 auto position = key.data.window_space_position;
1014 int lane;
1015 if (FindWindowSpacePosition(thread_id, position, &lane)) {
1016 DEBUGGER_LOG("Breakpoint hit: VertexId: [%d, %d]", position.x,
1017 position.y);
1018 StartThread(thread_id, lane, script);
1019 pendingThreads_.erase(it);
1020 return;
1021 }
1022 break;
1023 }
1024 }
1025 }
1026
1027 // No pending tests for this thread. Let it carry on...
1028 dap::ContinueRequest request;
1029 request.threadId = thread_id;
1030 client.Send(request);
1031 }
1032
1033 // OnStep is called when a debugger finishes a step request.
OnStep(dap::integer thread_id)1034 void OnStep(dap::integer thread_id) {
1035 DEBUGGER_LOG("Step event: thread %d", (int)thread_id);
1036 Client client(session_, [this](const std::string& err) { OnError(err); });
1037
1038 std::unique_lock<std::mutex> lock(threads_mutex_);
1039 auto it = runningThreads_.find(thread_id);
1040 if (it == runningThreads_.end()) {
1041 OnError("Step event reported for non-running thread " +
1042 std::to_string(thread_id.operator int64_t()));
1043 }
1044
1045 it->second->OnStep();
1046 }
1047
1048 // FindLocal looks for the shader's local variable with the given name and
1049 // value in the stack frames' locals, returning true if found, and assigns the
1050 // index of the SIMD lane it was found in to |lane|.
1051 template <typename T>
FindLocal(dap::integer thread_id,const char * name,const T & value,int * lane)1052 bool FindLocal(dap::integer thread_id,
1053 const char* name,
1054 const T& value,
1055 int* lane) {
1056 Client client(session_, [this](const std::string& err) { OnError(err); });
1057
1058 dap::StackFrame frame;
1059 if (!client.TopStackFrame(thread_id, &frame)) {
1060 return false;
1061 }
1062
1063 dap::ScopesRequest scopeReq;
1064 dap::ScopesResponse scopeRes;
1065 scopeReq.frameId = frame.id;
1066 if (!client.Send(scopeReq, &scopeRes)) {
1067 return false;
1068 }
1069
1070 Variables locals;
1071 if (!client.GetLocals(frame, &locals)) {
1072 return false;
1073 }
1074
1075 for (int i = 0;; i++) {
1076 auto lane_var = client.GetLane(locals, i);
1077 if (!lane_var) {
1078 break;
1079 }
1080 if (auto var = lane_var->Find(name)) {
1081 T got;
1082 if (var->Get(&got)) {
1083 if (got == value) {
1084 *lane = i;
1085 return true;
1086 }
1087 }
1088 }
1089 }
1090
1091 return false;
1092 }
1093
1094 // FindGlobalInvocationId looks for the compute shader's global invocation id
1095 // in the stack frames' locals, returning true if found, and assigns the index
1096 // of the SIMD lane it was found in to |lane|.
1097 // TODO(bclayton): This value should probably be in the globals, not locals!
FindGlobalInvocationId(dap::integer thread_id,const GlobalInvocationId & id,int * lane)1098 bool FindGlobalInvocationId(dap::integer thread_id,
1099 const GlobalInvocationId& id,
1100 int* lane) {
1101 return FindLocal(thread_id, kGlobalInvocationId, id, lane);
1102 }
1103
1104 // FindVertexIndex looks for the requested vertex shader's vertex index in the
1105 // stack frames' locals, returning true if found, and assigns the index of the
1106 // SIMD lane it was found in to |lane|.
1107 // TODO(bclayton): This value should probably be in the globals, not locals!
FindVertexIndex(dap::integer thread_id,uint32_t index,int * lane)1108 bool FindVertexIndex(dap::integer thread_id, uint32_t index, int* lane) {
1109 return FindLocal(thread_id, kVertexIndex, index, lane);
1110 }
1111
1112 // FindWindowSpacePosition looks for the fragment shader's window space
1113 // position in the stack frames' locals, returning true if found, and assigns
1114 // the index of the SIMD lane it was found in to |lane|.
1115 // TODO(bclayton): This value should probably be in the globals, not locals!
FindWindowSpacePosition(dap::integer thread_id,const WindowSpacePosition & pos,int * lane)1116 bool FindWindowSpacePosition(dap::integer thread_id,
1117 const WindowSpacePosition& pos,
1118 int* lane) {
1119 return FindLocal(thread_id, kWindowSpacePosition, pos, lane);
1120 }
1121
OnError(const std::string & error)1122 void OnError(const std::string& error) {
1123 DEBUGGER_LOG("ERROR: %s", error.c_str());
1124 error_ += error;
1125 }
1126
// Scripts waiting to be matched to a debugger thread, keyed by the
// invocation (global id, vertex index or window position) they target.
using PendingThreadsMap =
    std::unordered_map<InvocationKey,
                       std::shared_ptr<const debug::ThreadScript>,
                       InvocationKey::Hash>;
using ThreadVector = std::vector<std::unique_ptr<Thread>>;
// Threads keyed by debugger thread id.
using ThreadMap = std::unordered_map<int64_t, std::unique_ptr<Thread>>;
// Virtual file store handed to every Thread created in StartThread().
VirtualFileStore* const virtual_files_;
// DAP session used for all debugger requests; shared with each Thread.
std::shared_ptr<dap::Session> session_;
// Guards pendingThreads_, runningThreads_ and completedThreads_.
std::mutex threads_mutex_;

// The threads that are waiting to be started. Entries are added by the
// BreakOn*() methods and removed by OnBreakpointHit() once matched.
PendingThreadsMap pendingThreads_;  // guarded by threads_mutex_

// Threads currently being driven by a script, one per debugger thread id.
ThreadMap runningThreads_;  // guarded by threads_mutex_

// Threads displaced from runningThreads_ by a newer Thread with the same id.
// They have already been flushed and are kept alive until Flush() clears
// this list.
ThreadVector completedThreads_;  // guarded by threads_mutex_
std::mutex error_mutex_;
// Accumulated errors; included in (but not cleared by) Flush()'s result.
Result error_;  // guarded by error_mutex_
1147 };
1148 } // namespace
1149
GetDebugger(VirtualFileStore * virtual_files)1150 std::pair<Engine::Debugger*, Result> EngineVulkan::GetDebugger(
1151 VirtualFileStore* virtual_files) {
1152 if (!debugger_) {
1153 auto debugger = new VkDebugger(virtual_files);
1154 debugger_.reset(debugger);
1155 auto res = debugger->Connect();
1156 if (!res.IsSuccess()) {
1157 return {nullptr, res};
1158 }
1159 }
1160 return {debugger_.get(), Result()};
1161 }
1162
1163 } // namespace vulkan
1164 } // namespace amber
1165
1166 #else // AMBER_ENABLE_VK_DEBUGGING
1167
1168 namespace amber {
1169 namespace vulkan {
1170
GetDebugger(VirtualFileStore *)1171 std::pair<Engine::Debugger*, Result> EngineVulkan::GetDebugger(
1172 VirtualFileStore*) {
1173 return {nullptr,
1174 Result("Amber was not built with AMBER_ENABLE_VK_DEBUGGING enabled")};
1175 }
1176
1177 } // namespace vulkan
1178 } // namespace amber
1179
1180 #endif // AMBER_ENABLE_VK_DEBUGGING
1181