• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024-2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
17
18import dataclasses
19from contextlib import asynccontextmanager
20from dataclasses import dataclass
21from inspect import getfullargspec
22from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeAlias
23
24import trio
25import trio_cdp
26from cdp import debugger, runtime, profiler
27
28from arkdb.compiler import StringCodeCompiler
29from arkdb.debug_connection import ArkConnection, Proxy, ScriptsCache, SourcesCache, connect_cdp
30from arkdb.extensions import debugger as ext_debugger
31from arkdb.extensions import profiler as ext_profiler
32
# Generic event type re-exported from trio_cdp; used to annotate the
# event-typed helpers of this module (e.g. ``DebuggerClient.wait_for`` and
# ``DebuggerClient._listen``).
# NOTE(review): trio_cdp.T appears to be a TypeVar; some type checkers reject a
# ``TypeAlias`` annotation on a TypeVar — consider plain ``T = trio_cdp.T``.
T: TypeAlias = trio_cdp.T
34
35
@dataclasses.dataclass
class DebuggerConfig:
    """Session-wide settings for a debugger client.

    Attributes:
        pause_on_exceptions_mode: mode used for ``debugger.set_pause_on_exceptions``
            whenever no explicit mode is supplied; ``"none"`` by default.
    """

    pause_on_exceptions_mode: Literal["none", "caught", "uncaught", "all"] = "none"
39
40
class DebuggerClient:
    """Async client for one debugging session over an :class:`ArkConnection`.

    Most methods send a single CDP command (``debugger``, ``runtime`` and
    ``profiler`` domains, plus the ``arkdb.extensions`` commands) and return
    the decoded result. The stepping/resuming helpers additionally wait for
    the ``debugger.Paused`` / ``debugger.Resumed`` event the command triggers.
    """

    def __init__(
        self,
        connection: ArkConnection,
        config: DebuggerConfig,
        debugger_id: runtime.UniqueDebuggerId,
        scripts: ScriptsCache,
        sources: SourcesCache,
        context: runtime.ExecutionContextDescription,
        code_compiler: StringCodeCompiler,
    ) -> None:
        self.connection = connection
        self.config = config
        self.debugger_id = debugger_id
        self.scripts = scripts
        self.sources = sources
        self.context = context
        self.code_compiler = code_compiler

    async def configure(self, nursery: trio.Nursery):
        """Install background event handling in *nursery* and apply ``self.config``.

        Registers a ``runtime.ExecutionContextsCleared`` handler that cancels
        *nursery*, then sends the configured pause-on-exceptions mode.
        """
        self._listen(nursery, self._create_on_execution_contexts_cleared(nursery))
        await self.set_pause_on_exceptions()

    async def run_if_waiting_for_debugger(self) -> debugger.Paused:
        """Send ``runtime.run_if_waiting_for_debugger`` and wait for the next ``debugger.Paused`` event."""
        return await self.connection.send_and_wait_for(
            runtime.run_if_waiting_for_debugger(),
            debugger.Paused,
        )

    @asynccontextmanager
    async def wait_for(self, event_type: Type[T], buffer_size: int = 1):
        """Context manager wrapping ``ArkConnection.wait_for`` for *event_type*.

        Yields the connection's proxy object; callers in this class read
        ``proxy.value`` after leaving the ``async with`` block.
        """
        proxy: Proxy[T]
        async with self.connection.wait_for(event_type=event_type, buffer_size=buffer_size) as proxy:
            yield proxy

    async def set_pause_on_exceptions(
        self,
        mode: Optional[Literal["none", "caught", "uncaught", "all"]] = None,
    ):
        """Set when the debugger pauses on exceptions.

        Args:
            mode: explicit mode; ``None`` falls back to ``self.config.pause_on_exceptions_mode``.
        """
        await self.connection.send(
            debugger.set_pause_on_exceptions(
                mode if mode is not None else self.config.pause_on_exceptions_mode,
            ),
        )

    async def resume(self) -> debugger.Resumed:
        """Resume execution and return the ``debugger.Resumed`` event."""
        return await self.connection.send_and_wait_for(
            debugger.resume(),
            debugger.Resumed,
        )

    async def resume_and_wait_for_paused(self) -> debugger.Paused:
        """Resume execution and return the next ``debugger.Paused`` event."""
        # The wait is armed before resuming, presumably so an immediate pause is not missed.
        async with self.wait_for(debugger.Paused) as proxy:
            await self.resume()
        # Scheduling point before reading the proxy's value.
        await trio.lowlevel.checkpoint()
        return proxy.value

    async def send_and_wait_for_paused(self, send_arg) -> debugger.Paused:
        """Send the CDP command *send_arg* and return the next ``debugger.Paused`` event."""
        # The wait is armed before sending, presumably so an immediate pause is not missed.
        async with self.wait_for(debugger.Paused) as proxy:
            await self.connection.send(send_arg)
        # Scheduling point before reading the proxy's value.
        await trio.lowlevel.checkpoint()
        return proxy.value

    async def continue_to_location(
        self,
        script_id: runtime.ScriptId,
        line_number: int,
    ) -> debugger.Paused:
        """Continue to *line_number* of *script_id* and return the resulting ``debugger.Paused`` event."""
        return await self.connection.send_and_wait_for(
            debugger.continue_to_location(
                location=debugger.Location(
                    script_id=script_id,
                    line_number=line_number,
                ),
            ),
            debugger.Paused,
        )

    async def get_script_source(
        self,
        script_id: runtime.ScriptId,
    ) -> str:
        """Fetch the source of *script_id* from the target, bypassing ``self.sources``.

        NOTE(review): depending on the ``cdp`` package version,
        ``debugger.get_script_source`` may decode to a ``(source, bytecode)``
        tuple rather than ``str`` — confirm against the pinned version.
        """
        return await self.connection.send(debugger.get_script_source(script_id))

    async def get_script_source_cached(
        self,
        script_id: runtime.ScriptId,
    ) -> str:
        """Return the source of *script_id* through ``self.sources``, with :meth:`get_script_source` as the fetch callback."""
        return await self.sources.get(script_id, self.get_script_source)

    async def get_properties(
        self,
        object_id: runtime.RemoteObjectId,
        own_properties: Optional[bool] = None,
        accessor_properties_only: Optional[bool] = None,
        generate_preview: Optional[bool] = None,
    ) -> Tuple[
        List[runtime.PropertyDescriptor],
        Optional[List[runtime.InternalPropertyDescriptor]],
        Optional[List[runtime.PrivatePropertyDescriptor]],
        Optional[runtime.ExceptionDetails],
    ]:
        """Send ``runtime.get_properties`` for *object_id* and return its decoded result tuple."""
        return await self.connection.send(
            runtime.get_properties(
                object_id=object_id,
                own_properties=own_properties,
                accessor_properties_only=accessor_properties_only,
                generate_preview=generate_preview,
            )
        )

    async def set_breakpoint(
        self,
        location: debugger.Location,
        condition: Optional[str] = None,
    ) -> Tuple[debugger.BreakpointId, debugger.Location]:
        """Set a breakpoint at *location* (optionally conditional) and return its id and actual location."""
        return await self.connection.send(
            debugger.set_breakpoint(
                location=location,
                condition=condition,
            ),
        )

    async def remove_breakpoint(self, breakpoint_id: debugger.BreakpointId) -> None:
        """Remove the breakpoint *breakpoint_id*."""
        return await self.connection.send(
            debugger.remove_breakpoint(
                breakpoint_id=breakpoint_id,
            ),
        )

    async def remove_breakpoints_by_url(self, url: str) -> None:
        """Send the ``ext_debugger.remove_breakpoints_by_url`` extension command for *url*."""
        await self.connection.send(ext_debugger.remove_breakpoints_by_url(url))

    async def set_breakpoint_by_url(self, *args, **kwargs) -> Tuple[debugger.BreakpointId, List[debugger.Location]]:
        """Forward all arguments to ``debugger.set_breakpoint_by_url`` and return its id and locations."""
        return await self.connection.send(debugger.set_breakpoint_by_url(*args, **kwargs))

    async def get_possible_breakpoints(
        self,
        start: debugger.Location,
        end: Optional[debugger.Location] = None,
        restrict_to_function: Optional[bool] = None,
    ) -> List[debugger.BreakLocation]:
        """Return the possible break locations between *start* and *end*."""
        return await self.connection.send(
            debugger.get_possible_breakpoints(
                start=start,
                end=end,
                restrict_to_function=restrict_to_function,
            )
        )

    async def get_possible_and_set_breakpoint_by_url(
        self,
        locations: List[ext_debugger.UrlBreakpointRequest],
    ) -> List[ext_debugger.CustomUrlBreakpointResponse]:
        """Send the ``ext_debugger.get_possible_and_set_breakpoint_by_url`` extension command for *locations*."""
        return await self.connection.send(ext_debugger.get_possible_and_set_breakpoint_by_url(locations))

    async def set_breakpoints_active(self, active: bool) -> None:
        """Activate or deactivate all breakpoints."""
        await self.connection.send(debugger.set_breakpoints_active(active=active))

    async def set_skip_all_pauses(self, skip: bool) -> None:
        """Enable or disable skipping of all pauses."""
        await self.connection.send(debugger.set_skip_all_pauses(skip=skip))

    async def evaluate_on_call_frame(
        self,
        call_frame_id: debugger.CallFrameId,
        expression: str,
        object_group: Optional[str] = None,
        include_command_line_api: Optional[bool] = None,
        silent: Optional[bool] = None,
        return_by_value: Optional[bool] = None,
    ) -> Tuple[runtime.RemoteObject, Optional[runtime.ExceptionDetails]]:
        """Evaluate *expression* in the context of call frame *call_frame_id*."""
        return await self.connection.send(
            debugger.evaluate_on_call_frame(
                call_frame_id=call_frame_id,
                expression=expression,
                object_group=object_group,
                include_command_line_api=include_command_line_api,
                silent=silent,
                return_by_value=return_by_value,
            )
        )

    async def evaluate(self, expression: str) -> Tuple[runtime.RemoteObject, Optional[runtime.ExceptionDetails]]:
        """Evaluate *expression* with ``runtime.evaluate`` and return the result and any exception details."""
        return await self.connection.send(runtime.evaluate(expression))

    async def restart_frame(self, frame_number: int) -> debugger.Paused:
        """Restart the call frame whose id is ``str(frame_number)`` and return the next ``debugger.Paused`` event."""
        return await self.send_and_wait_for_paused(debugger.restart_frame(debugger.CallFrameId(str(frame_number))))

    async def step_into(self) -> debugger.Paused:
        """Step into and return the next ``debugger.Paused`` event."""
        return await self.send_and_wait_for_paused(debugger.step_into())

    async def step_out(self) -> debugger.Paused:
        """Step out and return the next ``debugger.Paused`` event."""
        return await self.send_and_wait_for_paused(debugger.step_out())

    async def step_over(self) -> debugger.Paused:
        """Step over and return the next ``debugger.Paused`` event."""
        return await self.send_and_wait_for_paused(debugger.step_over())

    async def profiler_enable(self) -> None:
        """Enable the profiler domain."""
        await self.connection.send(profiler.enable())

    async def profiler_set_sampling_interval(self, interval: int) -> None:
        """Set the profiler sampling interval to *interval*."""
        await self.connection.send(profiler.set_sampling_interval(interval))

    async def profiler_start(self) -> None:
        """Start profiling."""
        await self.connection.send(profiler.start())

    async def profiler_stop(self) -> ext_profiler.ProfileArray:
        """Stop profiling via the ``ext_profiler.profiler_stop`` extension and return the collected profiles."""
        return await self.connection.send(ext_profiler.profiler_stop())

    async def profiler_disable(self) -> None:
        """Disable the profiler domain."""
        await self.connection.send(profiler.disable())

    def _create_on_execution_contexts_cleared(self, nursery: trio.Nursery):
        """Build a ``runtime.ExecutionContextsCleared`` handler that cancels *nursery*."""

        def _on_execution_contexts_cleared(_: runtime.ExecutionContextsCleared):
            # A deadlock can occur when client awaits a response after server's disconnect.
            # ArkTS debugger implementation notifies about execution end via `runtime.ExecutionContextsCleared` event,
            # which is used here to force client disconnect.
            nursery.cancel_scope.cancel()

        return _on_execution_contexts_cleared

    def _listen(
        self,
        nursery: trio.Nursery,
        handler: Callable[[T], None],
    ):
        """Spawn a task in *nursery* that calls *handler* for every event of its annotated type.

        The event type is read from the annotation of *handler*'s first
        annotated parameter. It is resolved here, before the task is spawned,
        so a handler without a usable annotation fails at the call site with a
        ``TypeError`` instead of crashing the nursery from inside the task.
        """
        # Passing `T` as event type will not work: the concrete class is needed at runtime.
        event_type = self._handler_event_type(handler)

        async def _t():
            async for event in self.connection.listen(event_type):
                handler(event)

        nursery.start_soon(_t)

    @staticmethod
    def _handler_event_type(handler: Callable[..., Any]) -> Any:
        """Return the annotation of the first annotated parameter of *handler*.

        The ``return`` annotation is ignored so it can never be mistaken for the
        event type.

        Raises:
            TypeError: if no parameter of *handler* is annotated.
        """
        annotations = getfullargspec(handler).annotations
        param_annotations = [ann for name, ann in annotations.items() if name != "return"]
        if not param_annotations:
            raise TypeError(f"Cannot infer the event type of {handler!r}: annotate its first parameter")
        return param_annotations[0]
275
276
@asynccontextmanager
async def create_debugger_client(
    connection: ArkConnection,
    scripts: ScriptsCache,
    sources: SourcesCache,
    code_compiler: StringCodeCompiler,
    debugger_config: Optional[DebuggerConfig] = None,
) -> AsyncIterator[DebuggerClient]:
    """Enable the runtime and debugger domains on *connection* and yield a :class:`DebuggerClient`.

    Sends ``runtime.enable`` and waits for the ``runtime.ExecutionContextCreated``
    event, then sends ``debugger.enable`` to obtain the debugger id.

    Args:
        connection: open connection to the debuggee.
        scripts: script cache handed to the client.
        sources: source cache handed to the client.
        code_compiler: compiler handed to the client.
        debugger_config: client settings. When omitted, a fresh
            :class:`DebuggerConfig` is created per call; ``DebuggerConfig`` is a
            mutable dataclass, so a single default instance shared between
            calls would leak one session's changes into the next.
    """
    if debugger_config is None:
        debugger_config = DebuggerConfig()
    context_created = await connection.send_and_wait_for(
        runtime.enable(),
        runtime.ExecutionContextCreated,
    )
    debugger_id = await connection.send(
        debugger.enable(),
    )
    yield DebuggerClient(
        connection=connection,
        config=debugger_config,
        debugger_id=debugger_id,
        scripts=scripts,
        sources=sources,
        context=context_created.context,
        code_compiler=code_compiler,
    )
301
302
class BreakpointManager:
    """Track breakpoints set by URL together with the locations they resolved to."""

    def __init__(self, client: DebuggerClient) -> None:
        self._lock = trio.Lock()
        self.client = client
        # breakpoint id -> locations returned by ``set_breakpoint_by_url``
        self._breaks: Dict[debugger.BreakpointId, List[debugger.Location]] = dict()

    async def set_by_url(self, line_number: int, url: Optional[str]) -> None:
        """Activate breakpoints, set one at *line_number* of *url* and record its locations."""
        await self.client.set_breakpoints_active(True)
        br, locs = await self.client.set_breakpoint_by_url(line_number=line_number, url=url)
        async with self._lock:
            self._breaks[br] = locs

    async def get(self, bp_id: debugger.BreakpointId) -> Optional[List[debugger.Location]]:
        """Return the recorded locations of *bp_id*, or ``None`` if it is unknown."""
        async with self._lock:
            return self._breaks.get(bp_id)

    @asynccontextmanager
    async def get_all(self) -> AsyncIterator[AsyncIterator[Tuple[debugger.BreakpointId, List[debugger.Location]]]]:
        """Yield an async iterator over a snapshot of ``(breakpoint_id, locations)`` pairs.

        The snapshot is taken under the lock, so iteration does not hold the lock
        and does not see breakpoints recorded afterwards. An
        ``@asynccontextmanager`` must yield exactly once, so the pairs are
        delivered through the yielded iterator rather than by repeated yields.
        """
        async with self._lock:
            snapshot = list(self._breaks.items())

        async def _iterate():
            for br, locs in snapshot:
                yield br, locs
                await trio.lowlevel.checkpoint()

        yield _iterate()

    async def get_possible_breakpoints(self) -> Dict[debugger.BreakpointId, List[debugger.BreakLocation]]:
        """Query the possible break locations on the line of every recorded breakpoint location.

        For each location, the range from column 0 to column 1 of its line is
        queried. The results for all locations of one breakpoint are
        concatenated under that breakpoint's id.
        """
        # Chrome does this after set_breakpoint_by_url
        result: Dict[debugger.BreakpointId, List[debugger.BreakLocation]] = {}
        async with self.get_all() as pairs:
            async for br, locs in pairs:
                br_locs: List[debugger.BreakLocation] = []
                for loc in locs:
                    br_locs.extend(
                        await self.client.get_possible_breakpoints(
                            start=dataclasses.replace(loc, column_number=0),
                            end=dataclasses.replace(loc, column_number=1),
                        )
                    )
                result[br] = br_locs
        return result
342
343
class DebugLocator:
    """Factory for debugger sessions against a single endpoint.

    One :class:`ScriptsCache` and one :class:`SourcesCache` are created per
    locator and passed to every client produced by :meth:`connect`, so clients
    obtained from the same locator share them.
    """

    scripts: ScriptsCache
    sources: SourcesCache

    def __init__(self, code_compiler: StringCodeCompiler, url: Any) -> None:
        self.url = url
        self.code_compiler = code_compiler
        self.scripts = ScriptsCache()
        self.sources = SourcesCache()

    @asynccontextmanager
    async def connect(self, nursery: trio.Nursery) -> AsyncIterator[DebuggerClient]:
        """Open a CDP connection to ``self.url`` and yield a :class:`DebuggerClient` on it.

        The CDP connection object's async context is exited when this context
        manager exits.
        """
        # NOTE(review): the meaning of the literal 10 is defined by connect_cdp in
        # arkdb.debug_connection — confirm before changing it.
        cdp_conn = await connect_cdp(nursery, self.url, 10)
        async with cdp_conn:
            client_cm = create_debugger_client(
                ArkConnection(cdp_conn, nursery),
                self.scripts,
                self.sources,
                self.code_compiler,
            )
            async with client_cm as client:
                yield client
366