• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import dataclasses
19from contextlib import asynccontextmanager
20from dataclasses import dataclass
21from inspect import getfullargspec
22from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeAlias
23
24import trio
25import trio_cdp
26from cdp import debugger, runtime
27
28from arkdb.compiler import StringCodeCompiler
29from arkdb.debug_connection import ArkConnection, Proxy, ScriptsCache, SourcesCache, connect_cdp
30
# Alias of `trio_cdp.T`; used in this module's annotations to stand for a CDP event type.
T: TypeAlias = trio_cdp.T
32
33
@dataclasses.dataclass
class DebuggerConfig:
    """Client-side debugger settings.

    Holds the pause-on-exceptions mode that `DebuggerClient.set_pause_on_exceptions`
    falls back to when it is called without an explicit mode.
    """

    # One of the four literal modes; "none" is the default.
    pause_on_exceptions_mode: Literal["none", "caught", "uncaught", "all"] = dataclasses.field(default="none")
37
38
class DebuggerClient:
    """Facade over an `ArkConnection` for CDP `debugger` / `runtime` commands.

    Each public coroutine builds the matching `cdp` command and forwards it through
    `self.connection`. The stepping helpers additionally wait for the
    `debugger.Paused` event that follows the command.
    """

    def __init__(
        self,
        connection: ArkConnection,
        config: DebuggerConfig,
        debugger_id: runtime.UniqueDebuggerId,
        scripts: ScriptsCache,
        sources: SourcesCache,
        context: runtime.ExecutionContextDescription,
        code_compiler: StringCodeCompiler,
    ) -> None:
        """Store the connection, configuration, caches and session data on the client."""
        self.connection = connection
        self.config = config
        self.debugger_id = debugger_id
        self.scripts = scripts
        self.sources = sources
        self.context = context
        self.code_compiler = code_compiler

    async def configure(self, nursery: trio.Nursery):
        """Install the `ExecutionContextsCleared` listener and apply the configured pause mode.

        :param nursery: nursery that runs the listener task; its cancel scope is
            cancelled when the execution contexts are cleared.
        """
        self._listen(nursery, self._create_on_execution_contexts_cleared(nursery))
        await self.set_pause_on_exceptions()

    async def run_if_waiting_for_debugger(self) -> debugger.Paused:
        """Send `runtime.run_if_waiting_for_debugger` and return the awaited `debugger.Paused` event."""
        return await self.connection.send_and_wait_for(
            runtime.run_if_waiting_for_debugger(),
            debugger.Paused,
        )

    @asynccontextmanager
    async def wait_for(self, event_type: Type[T], buffer_size=1):
        """Delegate to `connection.wait_for` and yield the proxy it provides.

        :param event_type: CDP event class to wait for.
        :param buffer_size: forwarded unchanged to `connection.wait_for`.
        """
        proxy: Proxy[T]
        async with self.connection.wait_for(event_type=event_type, buffer_size=buffer_size) as proxy:
            yield proxy

    async def set_pause_on_exceptions(
        self,
        mode: Optional[Literal["none", "caught", "uncaught", "all"]] = None,
    ):
        """Send `debugger.set_pause_on_exceptions`.

        :param mode: mode to set; when `None`, `self.config.pause_on_exceptions_mode` is used.
        """
        await self.connection.send(
            debugger.set_pause_on_exceptions(
                mode if mode is not None else self.config.pause_on_exceptions_mode,
            ),
        )

    async def resume(self) -> debugger.Resumed:
        """Send `debugger.resume` and return the awaited `debugger.Resumed` event."""
        return await self.connection.send_and_wait_for(
            debugger.resume(),
            debugger.Resumed,
        )

    async def resume_and_wait_for_paused(self) -> debugger.Paused:
        """Resume execution and return the next `debugger.Paused` event."""
        # The wait_for subscription is opened before resuming;
        # presumably so that a quickly arriving Paused event is not missed.
        async with self.wait_for(debugger.Paused) as proxy:
            await self.resume()
        await trio.lowlevel.checkpoint()
        return proxy.value

    async def send_and_wait_for_paused(self, send_arg) -> debugger.Paused:
        """Send an arbitrary command and return the next `debugger.Paused` event.

        :param send_arg: command object passed unchanged to `connection.send`.
        """
        async with self.wait_for(debugger.Paused) as proxy:
            await self.connection.send(send_arg)
        await trio.lowlevel.checkpoint()
        return proxy.value

    async def continue_to_location(
        self,
        script_id: runtime.ScriptId,
        line_number: int,
    ) -> debugger.Paused:
        """Send `debugger.continue_to_location` for the given script line and await `debugger.Paused`."""
        return await self.connection.send_and_wait_for(
            debugger.continue_to_location(
                location=debugger.Location(
                    script_id=script_id,
                    line_number=line_number,
                ),
            ),
            debugger.Paused,
        )

    async def get_script_source(
        self,
        script_id: runtime.ScriptId,
    ) -> str:
        """Request the source of `script_id` via `debugger.get_script_source`."""
        return await self.connection.send(debugger.get_script_source(script_id))

    async def get_script_source_cached(
        self,
        script_id: runtime.ScriptId,
    ) -> str:
        """Look the source up through `self.sources`, passing `get_script_source` as the loader."""
        return await self.sources.get(script_id, self.get_script_source)

    async def get_properties(
        self,
        object_id: runtime.RemoteObjectId,
        own_properties: Optional[bool] = None,
        accessor_properties_only: Optional[bool] = None,
        generate_preview: Optional[bool] = None,
    ) -> Tuple[
        List[runtime.PropertyDescriptor],
        Optional[List[runtime.InternalPropertyDescriptor]],
        Optional[List[runtime.PrivatePropertyDescriptor]],
        Optional[runtime.ExceptionDetails],
    ]:
        """Send `runtime.get_properties` with the given arguments and return its response."""
        return await self.connection.send(
            runtime.get_properties(
                object_id=object_id,
                own_properties=own_properties,
                accessor_properties_only=accessor_properties_only,
                generate_preview=generate_preview,
            )
        )

    async def set_breakpoint(
        self,
        location: debugger.Location,
        condition: Optional[str] = None,
    ) -> Tuple[debugger.BreakpointId, debugger.Location]:
        """Send `debugger.set_breakpoint` for `location` with an optional `condition`."""
        return await self.connection.send(
            debugger.set_breakpoint(
                location=location,
                condition=condition,
            ),
        )

    async def set_breakpoint_by_url(self, *args, **kwargs) -> Tuple[debugger.BreakpointId, List[debugger.Location]]:
        """Send `debugger.set_breakpoint_by_url`; all arguments are forwarded unchanged."""
        return await self.connection.send(debugger.set_breakpoint_by_url(*args, **kwargs))

    async def get_possible_breakpoints(
        self,
        start: debugger.Location,
        end: Optional[debugger.Location] = None,
        restrict_to_function: Optional[bool] = None,
    ) -> List[debugger.BreakLocation]:
        """Send `debugger.get_possible_breakpoints` for the range from `start` to `end`."""
        return await self.connection.send(
            debugger.get_possible_breakpoints(
                start=start,
                end=end,
                restrict_to_function=restrict_to_function,
            )
        )

    async def set_breakpoints_active(self, active: bool) -> None:
        """Send `debugger.set_breakpoints_active` with the given flag."""
        await self.connection.send(debugger.set_breakpoints_active(active=active))

    async def evaluate(self, expression: str) -> tuple[runtime.RemoteObject, runtime.ExceptionDetails | None]:
        """Send `runtime.evaluate` for `expression` and return its response."""
        return await self.connection.send(runtime.evaluate(expression))

    async def restart_frame(self, frame_number: int) -> debugger.Paused:
        """Send `debugger.restart_frame` and return the next `debugger.Paused` event.

        :param frame_number: frame number; converted to a string-based `debugger.CallFrameId`.
        """
        return await self.send_and_wait_for_paused(debugger.restart_frame(debugger.CallFrameId(str(frame_number))))

    async def step_into(self) -> debugger.Paused:
        """Send `debugger.step_into` and return the next `debugger.Paused` event."""
        return await self.send_and_wait_for_paused(debugger.step_into())

    async def step_out(self) -> debugger.Paused:
        """Send `debugger.step_out` and return the next `debugger.Paused` event."""
        return await self.send_and_wait_for_paused(debugger.step_out())

    async def step_over(self) -> debugger.Paused:
        """Send `debugger.step_over` and return the next `debugger.Paused` event."""
        return await self.send_and_wait_for_paused(debugger.step_over())

    def _create_on_execution_contexts_cleared(self, nursery: trio.Nursery):
        """Build the handler that cancels `nursery` on `runtime.ExecutionContextsCleared`."""

        def _on_execution_contexts_cleared(_: runtime.ExecutionContextsCleared) -> None:
            # A deadlock can occur when client awaits a response after server's disconnect.
            # ArkTS debugger implementation notifies about execution end via `runtime.ExecutionContextsCleared` event,
            # which is used here to force client disconnect.
            nursery.cancel_scope.cancel()

        return _on_execution_contexts_cleared

    @staticmethod
    def _resolve_event_type(handler: Callable[[T], None]) -> Type[T]:
        """Return the annotation of the first annotated positional parameter of `handler`.

        Only positional parameter names are inspected, so a `return` annotation is
        never mistaken for the event type and an unannotated leading parameter
        (for example `self` of a bound method) is skipped.

        :raises TypeError: if no positional parameter of `handler` is annotated.
        """
        spec = getfullargspec(handler)
        for arg_name in spec.args:
            if arg_name in spec.annotations:
                return spec.annotations[arg_name]
        raise TypeError(f"Cannot infer event type: {handler!r} has no annotated positional parameter")

    def _listen(
        self,
        nursery: trio.Nursery,
        handler: Callable[[T], None],
    ):
        """Start a background task in `nursery` that feeds matching events to `handler`.

        The event type is taken from the handler's parameter annotation and is
        resolved before the task is started, so an unusable handler is reported
        to the caller instead of failing inside the nursery.

        :raises TypeError: if the event type cannot be inferred from `handler`.
        """
        # Passing `T` as event type will not work
        event_type = self._resolve_event_type(handler)

        async def _t():
            async for event in self.connection.listen(event_type):
                handler(event)

        nursery.start_soon(_t)
219
220
@asynccontextmanager
async def create_debugger_client(
    connection: ArkConnection,
    scripts: ScriptsCache,
    sources: SourcesCache,
    code_compiler: StringCodeCompiler,
    debugger_config: Optional[DebuggerConfig] = None,
) -> AsyncIterator[DebuggerClient]:
    """Enable the runtime and debugger domains and yield a `DebuggerClient`.

    Sends `runtime.enable` and waits for `runtime.ExecutionContextCreated`, then
    sends `debugger.enable`; the results are stored on the yielded client.

    :param connection: connection used for the setup commands and by the client.
    :param scripts: scripts cache handed to the client.
    :param sources: sources cache handed to the client.
    :param code_compiler: compiler handed to the client.
    :param debugger_config: client configuration; when omitted, a fresh
        `DebuggerConfig()` is created for this client.
    """
    # `DebuggerConfig` is mutable, so the default is created per call rather than
    # once at definition time; otherwise every client would share one instance.
    if debugger_config is None:
        debugger_config = DebuggerConfig()

    context = await connection.send_and_wait_for(
        runtime.enable(),
        runtime.ExecutionContextCreated,
    )
    debugger_id = await connection.send(
        debugger.enable(),
    )
    yield DebuggerClient(
        connection=connection,
        config=debugger_config,
        debugger_id=debugger_id,
        scripts=scripts,
        sources=sources,
        context=context.context,
        code_compiler=code_compiler,
    )
245
246
class BreakpointManager:
    """Registry of breakpoints set through a `DebuggerClient`.

    Maps every breakpoint id to the locations reported when it was set.
    Access to the mapping is guarded by a `trio.Lock`.
    """

    def __init__(self, client: DebuggerClient) -> None:
        """Create an empty registry bound to `client`."""
        self._lock = trio.Lock()
        self.client = client
        self._breaks: Dict[debugger.BreakpointId, List[debugger.Location]] = dict()

    async def set_by_url(self, line_number: int, url: Optional[str]) -> None:
        """Activate breakpoints, set one by URL and record the locations returned for it."""
        await self.client.set_breakpoints_active(True)
        br, locs = await self.client.set_breakpoint_by_url(line_number=line_number, url=url)
        async with self._lock:
            self._breaks[br] = locs

    async def get(self, bp_id: debugger.BreakpointId) -> Optional[List[debugger.Location]]:
        """Return the locations recorded for `bp_id`, or `None` if the id is unknown."""
        async with self._lock:
            return self._breaks.get(bp_id)

    @asynccontextmanager
    async def get_all(self) -> AsyncIterator[AsyncIterator[Tuple[debugger.BreakpointId, List[debugger.Location]]]]:
        """Yield an async iterator over `(breakpoint id, locations)` pairs.

        The pairs come from a snapshot taken under the lock, so the registry may
        be modified while the caller is iterating. Usage::

            async with manager.get_all() as pairs:
                async for br, locs in pairs:
                    ...
        """
        async with self._lock:
            # `items()` (not the dict itself) so that each element is a (key, value) pair.
            snapshot = list(self._breaks.items())

        async def _pairs():
            for br, locs in snapshot:
                yield br, locs
                await trio.lowlevel.checkpoint()

        # An async context manager must yield exactly once, so the per-breakpoint
        # iteration lives in a separate async generator handed to the caller.
        pairs = _pairs()
        try:
            yield pairs
        finally:
            await pairs.aclose()

    async def get_possible_breakpoints(self) -> Dict[debugger.BreakpointId, List[debugger.BreakLocation]]:
        """Query possible break locations for every recorded breakpoint.

        For each recorded location the range from column 0 to column 1 of that
        location is queried. The results for all locations of one breakpoint are
        concatenated; a breakpoint without recorded locations maps to an empty list.
        """
        # Chrome does this after set_breakpoint_by_url
        result: Dict[debugger.BreakpointId, List[debugger.BreakLocation]] = {}
        async with self.get_all() as pairs:
            async for br, locs in pairs:
                possible: List[debugger.BreakLocation] = []
                for loc in locs:
                    possible.extend(
                        await self.client.get_possible_breakpoints(
                            start=dataclasses.replace(loc, column_number=0),
                            end=dataclasses.replace(loc, column_number=1),
                        )
                    )
                result[br] = possible
        return result
286
287
class DebugLocator:
    """Entry point for opening debugger clients against one debug URL.

    Owns a `ScriptsCache` and a `SourcesCache` created in the constructor; the same
    two caches are passed to every client produced by `connect`.
    """

    scripts: ScriptsCache
    sources: SourcesCache

    def __init__(self, code_compiler: StringCodeCompiler, url: Any) -> None:
        """Create the caches and remember the target URL and the code compiler."""
        self.scripts = ScriptsCache()
        self.sources = SourcesCache()
        self.url = url
        self.code_compiler = code_compiler

    @asynccontextmanager
    async def connect(self, nursery: trio.Nursery) -> AsyncIterator[DebuggerClient]:
        """Open a CDP connection to `self.url` and yield a `DebuggerClient` bound to it.

        :param nursery: nursery passed to `connect_cdp` and to `ArkConnection`.
        """
        # The third positional argument (10) is forwarded as-is; its meaning is defined by `connect_cdp`.
        session = await connect_cdp(nursery, self.url, 10)
        async with session:
            client_factory = create_debugger_client(
                connection=ArkConnection(session, nursery),
                scripts=self.scripts,
                sources=self.sources,
                code_compiler=self.code_compiler,
            )
            async with client_factory as client:
                yield client
310