#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright (c) 2024 Huawei Device Co., Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Description: Responsible for websocket communication with the connect server
and with the per-instance debugger servers.
"""

import asyncio
import json

from install_lib import install
from fport import Fport

# websockets is installed on demand, so it can only be imported after install().
install('websockets')
import websockets.protocol
from websockets import connect, ConnectionClosed


class ToolchainWebSocket(object):
    """Manage websocket connections to a connect server and its debugger servers.

    One client talks to the connect server. For each instance the connect
    server announces (while fewer than ``debugger_server_connection_threshold``
    are connected), a separate client is opened to that instance's debugger
    server. Each connection is driven by background sender/receiver coroutines
    that exchange messages through ``asyncio.Queue`` objects, so test
    procedures only interact with the queues.
    """

    def __init__(self, driver, connect_server_port, debugger_server_port, print_protocol=True,
                 debugger_server_connection_threshold=3):
        """
        :param driver: test driver; ``driver._deivce.host`` supplies the server IP
            and ``driver.log_info`` is used for protocol logging.
        :param connect_server_port: port of the connect server on the server IP.
        :param debugger_server_port: first port passed to ``Fport`` when forwarding
            a debugger server; advanced past each port that gets connected.
        :param print_protocol: when False, protocol logging is disabled.
        :param debugger_server_connection_threshold: maximum number of debugger
            server connections kept open at the same time (default 3).
        """
        self.driver = driver
        # NOTE(review): `_deivce` looks like a misspelling of `_device`; it is kept
        # as-is because it must match the driver's actual attribute name — confirm.
        self.server_ip = driver._deivce.host
        self.connect_server_port = connect_server_port
        self.debugger_server_port = debugger_server_port
        self.debugger_server_connection_threshold = debugger_server_connection_threshold

        # Created inside main_task (see the comment there).
        self.to_send_msg_queue_for_connect_server = None
        self.received_msg_queue_for_connect_server = None

        self.to_send_msg_queues = {}  # key: instance_id, value: to_send_msg_queue
        self.received_msg_queues = {}  # key: instance_id, value: received_msg_queue
        self.debugger_server_instance = None
        self.new_instance_flag = None
        self.log = self.driver.log_info if print_protocol else (lambda s: None)

    async def recv_msg_of_debugger_server(self, instance_id, queue):
        """Wait for, log and return the next message from ``queue``.

        ``queue`` is expected to be ``self.received_msg_queues[instance_id]``;
        ``instance_id`` is only used for logging.
        """
        message = await queue.get()
        queue.task_done()
        self.log(f'[<==] Instance {instance_id} receive message: {message}')
        return message

    async def send_msg_to_debugger_server(self, instance_id, queue, message):
        """Queue ``message`` for the sender task of ``queue`` and return True.

        ``queue`` is expected to be ``self.to_send_msg_queues[instance_id]``.
        The sender JSON-encodes the message, except for the string ``'close'``,
        which closes the connection instead.
        """
        await queue.put(message)
        self.log(f'[==>] Instance {instance_id} send message: {message}')
        return True

    async def get_instance(self):
        """Wait for and return the id of the next instance whose debugger server is connected."""
        instance_id = await self.debugger_server_instance.get()
        self.debugger_server_instance.task_done()
        return instance_id

    def no_more_instance(self):
        """Set ``new_instance_flag`` to False.

        The flag is not read within this module; presumably callers poll it — confirm.
        """
        self.new_instance_flag = False

    async def recv_msg_of_connect_server(self):
        """Wait for and return the next raw (JSON text) message from the connect server."""
        message = await self.received_msg_queue_for_connect_server.get()
        self.received_msg_queue_for_connect_server.task_done()
        return message

    async def send_msg_to_connect_server(self, message):
        """Queue ``message`` for the connect server sender task and return True."""
        await self.to_send_msg_queue_for_connect_server.put(message)
        self.log(f'[==>] Connect server send message: {message}')
        return True

    async def main_task(self, taskpool, procedure, pid):
        """Connect to the connect server, start its I/O tasks, then run ``procedure(self)``.

        :param taskpool: object whose ``submit`` schedules coroutines.
        :param procedure: coroutine function taking this object.
        :param pid: process id passed to ``Fport`` when forwarding debugger servers.
        """
        # the async queue must be initialized in task
        self.to_send_msg_queue_for_connect_server = asyncio.Queue()
        self.received_msg_queue_for_connect_server = asyncio.Queue()
        # Holds at most one pending instance id; see _add_debugger_server_instance.
        self.debugger_server_instance = asyncio.Queue(maxsize=1)

        connect_server_client = await self._connect_connect_server()
        taskpool.submit(self._sender(connect_server_client, self.to_send_msg_queue_for_connect_server))
        taskpool.submit(self._receiver_of_connect_server(connect_server_client,
                                                         self.received_msg_queue_for_connect_server,
                                                         taskpool, pid))
        taskpool.submit(procedure(self))

    async def _sender(self, client, send_queue):
        """Send queued messages as JSON until the string ``'close'`` is dequeued, then close ``client``."""
        assert client.state == websockets.protocol.OPEN, f'Client state of _sender is: {client.state}'
        while True:
            send_message = await send_queue.get()
            send_queue.task_done()
            if send_message == 'close':
                await client.close(reason='close')
                return
            await client.send(json.dumps(send_message))

    async def _receiver(self, client, received_queue):
        """Queue every message received from a debugger server until its connection closes."""
        assert client.state == websockets.protocol.OPEN, f'Client state of _receiver is: {client.state}'
        while True:
            try:
                response = await client.recv()
                await received_queue.put(response)
            except ConnectionClosed:
                self.log('Debugger server connection closed')
                return

    def _connect(self, port, open_timeout):
        """Return an awaitable websocket connection to ``ws://<server_ip>:<port>``.

        Keep-alive pings are disabled (``ping_interval=None``).
        """
        return connect(f'ws://{self.server_ip}:{port}',
                       open_timeout=open_timeout,
                       ping_interval=None)

    def _connect_connect_server(self):
        """Return an awaitable connection to the connect server (10 s open timeout)."""
        return self._connect(self.connect_server_port, open_timeout=10)

    def _connect_debugger_server(self):
        """Return an awaitable connection to the current debugger server port (6 s open timeout)."""
        return self._connect(self.debugger_server_port, open_timeout=6)

    async def _receiver_of_connect_server(self, client, receive_queue, taskpool, pid):
        """Pump connect-server messages into ``receive_queue`` and manage debugger servers.

        Every raw message is queued (unparsed) for ``recv_msg_of_connect_server``.
        ``addInstance`` opens a debugger server connection for the instance while
        fewer than ``debugger_server_connection_threshold`` are connected;
        ``destroyInstance`` closes the matching connection. Returns when the
        connect server closes the connection.
        """
        assert client.state == websockets.protocol.OPEN, \
            f'Client state of _receiver_of_connect_server is: {client.state}'
        # Instances that actually got a debugger server connection. Instances
        # skipped because of the threshold are not tracked, so their
        # destroyInstance message must not touch the queues or the count.
        connected_instances = set()
        while True:
            try:
                response = await client.recv()
                await receive_queue.put(response)
                self.log(f'[<==] Connect server receive message: {response}')
                response = json.loads(response)

                # The connect server announces instance creation and destruction;
                # each connected instance gets its own debugger server client.
                msg_type = response.get('type')
                if msg_type == 'addInstance':
                    if len(connected_instances) < self.debugger_server_connection_threshold:
                        instance_id = response['instanceId']
                        await self._add_debugger_server_instance(instance_id, taskpool, pid)
                        connected_instances.add(instance_id)
                    else:
                        self.log(f'InstanceId: {response.get("instanceId")}, debugger server '
                                 f'connection threshold reached, instance ignored')

                elif msg_type == 'destroyInstance':
                    instance_id = response['instanceId']
                    if instance_id in connected_instances:
                        connected_instances.remove(instance_id)
                        await self.send_msg_to_debugger_server(
                            instance_id, self.to_send_msg_queues[instance_id], 'close')

            except ConnectionClosed:
                self.log('Connect server connection closed')
                return

    async def _add_debugger_server_instance(self, instance_id, taskpool, pid):
        """Forward, connect and start I/O tasks for one instance's debugger server.

        Registers the instance's send/receive queues, then publishes the id via
        ``_store_instance``. That call waits while a previously published id has
        not been taken by ``get_instance`` (the queue holds one id); meanwhile
        the caller processes no further connect-server messages.
        """
        port = Fport.fport_debugger_server(self.debugger_server_port, pid, instance_id)
        assert port > 0, 'Failed to fport debugger server for 3 times, the port is very likely occupied'
        self.debugger_server_port = port
        debugger_server_client = await self._connect_debugger_server()
        self.log(f'InstanceId: {instance_id}, port: {self.debugger_server_port}, '
                 f'debugger server connected')
        # The next instance starts forwarding from the following port.
        self.debugger_server_port += 1

        to_send_msg_queue = asyncio.Queue()
        received_msg_queue = asyncio.Queue()
        self.to_send_msg_queues[instance_id] = to_send_msg_queue
        self.received_msg_queues[instance_id] = received_msg_queue
        taskpool.submit(coroutine=self._sender(debugger_server_client, to_send_msg_queue))
        taskpool.submit(coroutine=self._receiver(debugger_server_client, received_msg_queue))

        await self._store_instance(instance_id)

    async def _store_instance(self, instance_id):
        """Publish ``instance_id`` for ``get_instance`` and return True."""
        await self.debugger_server_instance.put(instance_id)
        return True