#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Trio-based CDP connection helpers plus caches for parsed scripts and sources."""

import typing
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Coroutine, Dict, Generator, Optional, TypeAlias

import cdp
import cdp.util
import trio
import trio_cdp
from cdp import debugger, runtime

T: TypeAlias = trio_cdp.T
E = typing.TypeVar("E")
T_JSON_DICT: TypeAlias = cdp.util.T_JSON_DICT


class Proxy(typing.Generic[E]):
    """Placeholder for a value that is filled in later.

    ``None`` is used internally to mean "not set yet", so a proxy whose value
    was explicitly set to ``None`` still reads as unset.
    """

    def __init__(self) -> None:
        self._value: Optional[E] = None

    @property
    def value(self) -> E:
        """Return the stored value.

        Raises:
            ValueError: if no (non-``None``) value has been assigned yet.
        """
        if self._value is None:
            raise ValueError("Proxy value has not been set yet")
        return self._value

    @value.setter
    def value(self, value: E) -> None:
        self._value = value


class ArkConnection:
    """Wrapper around a ``trio_cdp.CdpConnection`` plus a nursery for background listeners."""

    def __init__(self, conn: trio_cdp.CdpConnection, nursery: trio.Nursery) -> None:
        self._conn = conn
        self._nursery = nursery
        # Declared only (no value assigned here); reading it before something
        # assigns it raises AttributeError.
        self.context_id: runtime.ExecutionContextId

    @property
    def connection(self) -> trio_cdp.CdpConnection:
        """The wrapped CDP connection."""
        return self._conn

    def listen(self, *event_types, buffer_size=10) -> trio.MemoryReceiveChannel:
        """Return the receive channel produced by the underlying ``listen`` for ``event_types``."""
        return self._conn.listen(*event_types, buffer_size=buffer_size)

    async def send(self, cmd: Generator[dict, T, Any]) -> T:
        """Execute the CDP command generator ``cmd`` and return its result."""
        return await self._conn.execute(cmd)

    @asynccontextmanager
    async def wait_for(
        self,
        event_type: typing.Type[T],
        buffer_size=10,
    ) -> AsyncGenerator[Proxy[T], None]:
        """Async context manager that waits for one event of ``event_type``.

        Yields a ``Proxy`` whose ``value`` is filled in after the ``async with``
        body exits. The value is copied from the underlying
        ``trio_cdp`` ``wait_for`` proxy; this assumes that proxy is populated
        when its own context exits.
        """
        cmd_proxy: trio_cdp.CmEventProxy
        proxy = Proxy[T]()
        async with self._conn.wait_for(event_type, buffer_size) as cmd_proxy:
            yield proxy
        # Only read the event once the inner context has exited.
        proxy.value = cmd_proxy.value  # type: ignore[attr-defined]

    async def send_and_wait_for(
        self,
        cmd: Generator[dict, T, Any],
        event_type: typing.Type[E],
        buffer_size=10,
    ) -> E:
        """Execute ``cmd`` and return the awaited event of ``event_type``.

        The wait is entered before the command is executed, so the listener
        is already in place when the command is sent.
        """
        async with self.wait_for(event_type, buffer_size) as proxy:
            await self._conn.execute(cmd)
        return proxy.value

    def _listen(self, event_type: typing.Type[T], handler: Callable[[T], Any]) -> None:
        """Start a nursery task that calls ``handler`` for each ``event_type`` event."""

        async def _forward_events() -> None:
            async for event in self._conn.listen(event_type):
                handler(event)

        self._nursery.start_soon(_forward_events)


class DebugConnection(trio_cdp.CdpConnection):
    """``CdpConnection`` that closes all registered listener channels on shutdown."""

    async def aclose(self):
        """Close the connection, then close listener channels even if closing fails."""
        try:
            await super().aclose()
        finally:
            self._close_channels()

    async def reader_task(self):
        """Run the inherited reader loop and close listener channels when it ends."""
        try:
            await super()._reader_task()
        finally:
            self._close_channels()

    def _close_channels(self):
        """Close every send channel in ``self.channels`` and empty the mapping."""
        # Values of ``self.channels`` are iterables of send channels. Collecting
        # them into a set means a channel that appears under several keys is
        # closed only once.
        channels: set[trio.MemorySendChannel] = {ch for group in self.channels.values() for ch in group}
        self.channels.clear()
        for ch in channels:
            ch.close()


async def connect_cdp(nursery: trio.Nursery, url, max_retries: int) -> DebugConnection:
    """Open a ``DebugConnection`` to the websocket ``url`` and start its reader task.

    Connection attempts that raise ``OSError`` are retried one second apart,
    up to ``max_retries`` attempts in total. A non-positive ``max_retries`` is
    treated as a single attempt rather than retrying forever.

    Raises:
        OSError: the error from the final attempt once attempts are exhausted.
    """
    attempts_left = max_retries
    while True:
        try:
            ws = await trio_cdp.connect_websocket_url(
                nursery,
                url,
                max_message_size=trio_cdp.MAX_WS_MESSAGE_SIZE,
            )
        except OSError:
            attempts_left -= 1
            # ``<= 0`` (not ``== 0``) so max_retries <= 0 cannot loop forever.
            if attempts_left <= 0:
                raise
            await trio.sleep(1)
            continue
        conn = DebugConnection(ws)
        nursery.start_soon(conn.reader_task)
        return conn


class ScriptsCache:
    """Lock-protected mapping of script id to ``debugger.ScriptParsed``."""

    def __init__(self) -> None:
        self._lock = trio.Lock()
        self._scripts: Dict[runtime.ScriptId, debugger.ScriptParsed] = {}

    async def __getitem__(self, script_id: runtime.ScriptId) -> debugger.ScriptParsed:
        """Return the script for ``script_id`` (use as ``await cache[script_id]``).

        Raises:
            KeyError: if ``script_id`` has not been added.
        """
        async with self._lock:
            if script_id not in self._scripts:
                raise KeyError(script_id)
            return self._scripts[script_id]

    async def get(self, script_id: runtime.ScriptId) -> Optional[debugger.ScriptParsed]:
        """Return the script for ``script_id``, or ``None`` if it is unknown."""
        async with self._lock:
            return self._scripts.get(script_id)

    async def get_by_url(self, url: str) -> Optional[debugger.ScriptParsed]:
        """Return the first cached script whose ``url`` equals ``url``, else ``None``."""
        async with self._lock:
            for s in self._scripts.values():
                if url == s.url:
                    return s
            return None

    async def add(self, *scripts: debugger.ScriptParsed) -> None:
        """Store ``scripts`` keyed by ``script_id``; a later entry replaces an earlier one."""
        async with self._lock:
            for s in scripts:
                self._scripts[s.script_id] = s


class SourcesCache:
    """Lock-protected cache of script source text keyed by script id."""

    def __init__(
        self,
    ) -> None:
        self._lock = trio.Lock()
        self._scripts: Dict[runtime.ScriptId, str] = {}

    async def get(
        self,
        script_id: runtime.ScriptId,
        get_source: Optional[Callable[[runtime.ScriptId], Coroutine[Any, Any, str]]] = None,
    ) -> str:
        """Return the cached source for ``script_id``, fetching it on a miss.

        On a cache miss, ``await get_source(script_id)`` is called and the
        result is stored. The lock is held across the fetch, so cache lookups
        and fetches are serialized.

        Raises:
            IndexError: on a cache miss when ``get_source`` is ``None``.
        """
        async with self._lock:
            result = self._scripts.get(script_id)
            if result is not None:
                return result
            if get_source is None:
                raise IndexError(script_id)
            result = await get_source(script_id)
            self._scripts[script_id] = result
            return result