#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
17
18import typing
19from contextlib import asynccontextmanager
20from typing import Any, AsyncGenerator, Callable, Coroutine, Dict, Generator, Optional, TypeAlias
21
22import cdp
23import cdp.util
24import trio
25import trio_cdp
26from cdp import debugger, runtime
27
# Type parameter of ``Proxy`` (and of ``send_and_wait_for``'s event type).
E = typing.TypeVar("E")
# Type variable shared with trio_cdp; used below both as a command's result
# type (``send``) and as an event type (``wait_for``).
# NOTE(review): this is a TypeVar, so the ``TypeAlias`` annotation may be
# rejected by strict type checkers — confirm intent.
T: TypeAlias = trio_cdp.T
# Local name for the cdp package's JSON-dictionary type.
T_JSON_DICT: TypeAlias = cdp.util.T_JSON_DICT
31
32
class Proxy(typing.Generic[E]):
    """Placeholder for a value that is filled in after it is handed out.

    ``ArkConnection.wait_for`` yields one of these before the awaited event
    arrives and assigns ``value`` afterwards. Reading ``value`` before any
    assignment raises ``ValueError``.
    """

    # Private "not assigned yet" marker. A dedicated object (rather than None)
    # lets ``None`` itself be stored and read back like any other value.
    _UNSET: typing.ClassVar[object] = object()

    def __init__(self) -> None:
        self._value: object = Proxy._UNSET

    @property
    def value(self) -> E:
        """The stored value.

        Raises:
            ValueError: if no value has been assigned yet.
        """
        if self._value is Proxy._UNSET:
            raise ValueError("Proxy value has not been set yet")
        return typing.cast(E, self._value)

    @value.setter
    def value(self, value: E) -> None:
        self._value = value
46
47
class ArkConnection:
    """Convenience wrapper around a ``trio_cdp.CdpConnection``.

    Pairs the connection with the nursery that background listener tasks
    started by ``_listen`` run in.
    """

    def __init__(self, conn: trio_cdp.CdpConnection, nursery: trio.Nursery) -> None:
        self._conn = conn
        self._nursery = nursery
        # Declared but not assigned here: reading it before a caller sets it
        # raises AttributeError.
        self.context_id: runtime.ExecutionContextId

    @property
    def connection(self) -> trio_cdp.CdpConnection:
        """The wrapped CDP connection."""
        return self._conn

    def listen(self, *event_types, buffer_size=10) -> trio.MemoryReceiveChannel:
        """Return a receive channel for ``event_types`` (delegates to ``CdpConnection.listen``)."""
        return self._conn.listen(*event_types, buffer_size=buffer_size)

    async def send(self, cmd: Generator[dict, T, Any]) -> T:
        """Execute the CDP command generator ``cmd`` and return its result."""
        return await self._conn.execute(cmd)

    @asynccontextmanager
    async def wait_for(self, event_type: typing.Type[T], buffer_size=10) -> AsyncGenerator[Proxy[T], None]:
        """Async context manager that captures an ``event_type`` event.

        Yields a ``Proxy`` immediately. Its ``value`` is assigned only after the
        ``async with`` body finishes and ``CdpConnection.wait_for`` has exited,
        so reading it inside the body raises ``ValueError``. The assigned value
        is whatever ``CdpConnection.wait_for`` captured (presumably the
        matching event — confirm against trio_cdp).
        """
        cmd_proxy: trio_cdp.CmEventProxy
        proxy = Proxy[T]()
        async with self._conn.wait_for(event_type, buffer_size) as cmd_proxy:
            yield proxy
        # Not reached if the body raised, so the proxy then stays unset.
        proxy.value = cmd_proxy.value  # type: ignore[attr-defined]

    async def send_and_wait_for(
        self,
        cmd: Generator[dict, T, Any],
        event_type: typing.Type[E],
        buffer_size=10,
    ) -> E:
        """Execute ``cmd`` and return the ``event_type`` value captured by ``wait_for``.

        The wait is entered before the command executes, presumably so that
        an event emitted in response to the command is not missed.
        """
        async with self.wait_for(event_type, buffer_size) as proxy:
            await self._conn.execute(cmd)
        return proxy.value

    def _listen(self, event_type: typing.Type[T], handler: Callable[[T], Any]) -> None:
        """Spawn a task in the nursery that calls ``handler`` for each ``event_type`` event.

        The task ends when the event channel reaches end-of-channel (for
        example after ``DebugConnection._close_channels`` closes the senders).
        """

        async def _dispatch_events() -> None:
            # ``async with`` closes the receive channel whenever the task exits
            # (including on cancellation), so the sender side sees it as closed
            # instead of buffering events for a receiver that no longer exists.
            async with self._conn.listen(event_type) as events:
                async for event in events:
                    handler(event)

        self._nursery.start_soon(_dispatch_events)
89
90
class DebugConnection(trio_cdp.CdpConnection):
    """``CdpConnection`` that closes every listener channel when it stops.

    Closing the send side of each registered channel makes ``async for``
    loops over the corresponding receive channels terminate, so listeners do
    not wait forever on a dead connection.
    """

    async def aclose(self):
        """Close the connection, then close all listener channels.

        The channels are closed in ``finally`` so listeners are released even
        if the base-class close raises or is cancelled.
        """
        try:
            await super().aclose()
        finally:
            self._close_channels()

    async def reader_task(self):
        """Run the base-class reader loop; close listener channels when it ends for any reason."""
        try:
            await super()._reader_task()
        finally:
            self._close_channels()

    def _close_channels(self):
        """Unregister and close every send channel held in ``self.channels``."""
        # Collected into a set so each channel is closed once, presumably
        # because one channel can be registered under several event types
        # (``listen`` accepts many). Safe to call repeatedly: the mapping is
        # cleared, so a second call finds nothing to close.
        channels: set[trio.MemorySendChannel] = {c for s in self.channels.values() for c in s}
        self.channels.clear()
        for ch in channels:
            ch.close()
102
103    def _close_channels(self):
104        channels: set[trio.MemorySendChannel] = set([c for s in self.channels.values() for c in s])
105        self.channels.clear()
106        for ch in channels:
107            ch.close()
108
109
async def connect_cdp(
    nursery: trio.Nursery,
    url,
    max_retries: int,
    *,
    retry_delay: float = 1,
) -> DebugConnection:
    """Open a CDP websocket connection to ``url``, retrying on ``OSError``.

    Args:
        nursery: nursery passed to the websocket connect call and used to run
            the connection's reader task.
        url: websocket URL to connect to.
        max_retries: total number of connection attempts; values below 1 are
            treated as a single attempt.
        retry_delay: seconds to sleep between failed attempts.

    Returns:
        A ``DebugConnection`` whose ``reader_task`` is already running in ``nursery``.

    Raises:
        OSError: the error from the last attempt once all attempts are used up.
    """
    attempts_left = max_retries
    while True:
        try:
            websocket = await trio_cdp.connect_websocket_url(
                nursery,
                url,
                max_message_size=trio_cdp.MAX_WS_MESSAGE_SIZE,
            )
        except OSError:
            attempts_left -= 1
            # ``<=`` instead of ``==``: a non-positive ``max_retries`` would
            # otherwise never hit exactly zero and would retry forever.
            if attempts_left <= 0:
                raise
            await trio.sleep(retry_delay)
        else:
            conn = DebugConnection(websocket)
            nursery.start_soon(conn.reader_task)
            return conn
128
129
class ScriptsCache:
    """Registry of ``debugger.ScriptParsed`` records keyed by script id.

    Every accessor takes ``self._lock`` around its dictionary access.
    """

    def __init__(self) -> None:
        self._lock = trio.Lock()
        self._scripts: Dict[runtime.ScriptId, debugger.ScriptParsed] = {}

    async def __getitem__(self, script_id: runtime.ScriptId) -> debugger.ScriptParsed:
        """Return the record for ``script_id``; awaited as ``await cache[script_id]``.

        Raises:
            KeyError: ``script_id`` has never been added.
        """
        async with self._lock:
            # Plain indexing raises KeyError(script_id) for an unknown id.
            return self._scripts[script_id]

    async def get(self, script_id: runtime.ScriptId) -> Optional[debugger.ScriptParsed]:
        """Return the record for ``script_id``, or None if it is unknown."""
        async with self._lock:
            found = self._scripts.get(script_id)
        return found

    async def get_by_url(self, url: str) -> Optional[debugger.ScriptParsed]:
        """Return the first stored record (insertion order) whose ``url`` equals ``url``, else None."""
        async with self._lock:
            return next(
                (parsed for parsed in self._scripts.values() if url == parsed.url),
                None,
            )

    async def add(self, *scripts: debugger.ScriptParsed) -> None:
        """Store each record under its ``script_id``; later duplicates overwrite earlier ones."""
        async with self._lock:
            self._scripts.update((parsed.script_id, parsed) for parsed in scripts)
156
157
class SourcesCache:
    """Cache of script source text keyed by script id, guarded by a ``trio.Lock``."""

    def __init__(self) -> None:
        self._lock = trio.Lock()
        self._scripts: Dict[runtime.ScriptId, str] = {}

    async def get(
        self,
        script_id: runtime.ScriptId,
        get_source: Optional[Callable[[runtime.ScriptId], Coroutine[Any, Any, str]]] = None,
    ) -> str:
        """Return the cached source for ``script_id``, loading it on a miss.

        Args:
            script_id: script whose source is requested.
            get_source: optional coroutine function called with ``script_id``
                on a cache miss; its result is stored and returned. If it
                raises, nothing is cached.

        Returns:
            The source text for ``script_id``.

        Raises:
            IndexError: on a cache miss when ``get_source`` is not given; the
                exception's argument is the missing ``script_id``.
        """
        # The lock stays held across ``get_source``: concurrent misses for one
        # id fetch only once, but fetches for different ids are serialized too,
        # and ``get_source`` must not call back into this cache because
        # ``trio.Lock`` is not re-entrant.
        async with self._lock:
            result = self._scripts.get(script_id)
            if result is not None:
                return result
            if get_source is None:
                raise IndexError(script_id)
            result = await get_source(script_id)
            self._scripts[script_id] = result
            return result
180