#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright (c) 2024 Huawei Device Co., Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Description: Responsible for websocket communication.
"""

import asyncio
import json
import logging

import websockets.protocol
from websockets import connect, ConnectionClosed

from aw.fport import Fport


class WebSocket(object):
    """Relay messages between a test procedure and the connect/debugger servers.

    Every websocket connection is paired with asyncio queues: callers put outgoing
    messages into a "to send" queue (drained by a ``_sender`` task) and read incoming
    messages from a "received" queue (filled by a receiver task). One connection is
    made to the connect server; one extra connection per announced instance is made
    to a debugger server, up to ``debugger_server_connection_threshold`` connections.
    """

    def __init__(self, connect_server_port, debugger_server_port):
        self.connect_server_port = connect_server_port
        # Port for the next debugger server connection; incremented after each one is opened.
        self.debugger_server_port = debugger_server_port
        # Maximum number of debugger server connections opened by _receiver_of_connect_server.
        self.debugger_server_connection_threshold = 3

        # Created in main_task, because asyncio queues must be built inside the running task.
        self.to_send_msg_queue_for_connect_server = None
        self.received_msg_queue_for_connect_server = None

        self.to_send_msg_queues = {}  # key: instance_id, value: to_send_msg_queue
        self.received_msg_queues = {}  # key: instance_id, value: received_msg_queue
        # Queue (maxsize=1, created in main_task) handing new instance ids to get_instance().
        self.debugger_server_instance = None

    @staticmethod
    async def recv_msg_of_debugger_server(instance_id, queue):
        """Wait for, log and return the next message received for ``instance_id``."""
        message = await queue.get()
        queue.task_done()
        logging.info(f'[<==] Instance {instance_id} receive message: {message}')
        return message

    @staticmethod
    async def send_msg_to_debugger_server(instance_id, queue, message):
        """Queue ``message`` for the debugger server of ``instance_id``; always returns True.

        The message is actually sent later by the ``_sender`` task draining ``queue``.
        The string ``'close'`` makes that task close the connection instead.
        """
        await queue.put(message)
        logging.info(f'[==>] Instance {instance_id} send message: {message}')
        return True

    @staticmethod
    def _ensure_client_open(client, role):
        """Log and raise AssertionError unless ``client`` is in the OPEN state.

        An explicit check is used instead of ``assert`` so it still runs under
        ``python -O``, and so the raised error actually carries the message.
        """
        if client.state != websockets.protocol.OPEN:
            message = f'Client state of {role} is: {client.state}'
            logging.error(message)
            raise AssertionError(message)

    @staticmethod
    async def _sender(client, send_queue):
        """Send every queued message (JSON-encoded) over ``client``.

        Stops after closing the connection when the sentinel ``'close'`` is dequeued.
        """
        WebSocket._ensure_client_open(client, '_sender')
        while True:
            send_message = await send_queue.get()
            send_queue.task_done()
            if send_message == 'close':
                await client.close(reason='close')
                return
            await client.send(json.dumps(send_message))

    @staticmethod
    async def _receiver(client, received_queue):
        """Forward every raw message from ``client`` into ``received_queue`` until it closes."""
        WebSocket._ensure_client_open(client, '_receiver')
        while True:
            try:
                response = await client.recv()
                await received_queue.put(response)
            except ConnectionClosed:
                logging.info('Debugger server connection closed')
                return

    async def get_instance(self):
        """Wait for and return the next instance id stored by ``_store_instance``."""
        instance_id = await self.debugger_server_instance.get()
        self.debugger_server_instance.task_done()
        return instance_id

    async def recv_msg_of_connect_server(self):
        """Wait for and return the next raw message received from the connect server."""
        message = await self.received_msg_queue_for_connect_server.get()
        self.received_msg_queue_for_connect_server.task_done()
        return message

    async def send_msg_to_connect_server(self, message):
        """Queue ``message`` for the connect server; always returns True."""
        await self.to_send_msg_queue_for_connect_server.put(message)
        logging.info(f'[==>] Connect server send message: {message}')
        return True

    async def main_task(self, taskpool, websocket, procedure, pid):
        """Connect to the connect server and submit the sender, receiver and procedure tasks.

        :param taskpool: object whose ``submit`` schedules a coroutine.
        :param websocket: passed unchanged to ``procedure``.
        :param procedure: coroutine function run as ``procedure(websocket)``.
        :param pid: process id forwarded to ``Fport.fport_debugger_server``.
        """
        # the async queue must be initialized in task
        self.to_send_msg_queue_for_connect_server = asyncio.Queue()
        self.received_msg_queue_for_connect_server = asyncio.Queue()
        self.debugger_server_instance = asyncio.Queue(maxsize=1)

        connect_server_client = await self._connect_connect_server()
        taskpool.submit(self._sender(connect_server_client, self.to_send_msg_queue_for_connect_server))
        taskpool.submit(self._receiver_of_connect_server(connect_server_client,
                                                         self.received_msg_queue_for_connect_server,
                                                         taskpool, pid))
        taskpool.submit(procedure(websocket))

    def _connect_connect_server(self):
        """Return an (un-awaited) websocket connect object for the connect server."""
        client = connect(f'ws://localhost:{self.connect_server_port}',
                         open_timeout=10,
                         ping_interval=None)
        return client

    def _connect_debugger_server(self):
        """Return an (un-awaited) websocket connect object for the current debugger server port."""
        client = connect(f'ws://localhost:{self.debugger_server_port}',
                         open_timeout=6,
                         ping_interval=None)
        return client

    async def _receiver_of_connect_server(self, client, receive_queue, taskpool, pid):
        """Forward connect-server messages to ``receive_queue`` and manage debugger connections.

        ``addInstance`` opens a debugger server connection (while under the threshold)
        with its own sender/receiver tasks and queues. ``destroyInstance`` asks that
        connection's sender to close. Every message is queued raw before it is parsed.
        """
        self._ensure_client_open(client, '_receiver_of_connect_server')
        num_debugger_server_client = 0
        while True:
            try:
                response = await client.recv()
                await receive_queue.put(response)
                logging.info(f'[<==] Connect server receive message: {response}')
                parsed = json.loads(response)
                # .get: a message without 'type' must not kill this receiver task.
                message_type = parsed.get('type')

                # The debugger server client is only responsible for adding and removing instances
                if (message_type == 'addInstance' and
                        num_debugger_server_client < self.debugger_server_connection_threshold):
                    instance_id = parsed['instanceId']

                    Fport.fport_debugger_server(self.debugger_server_port, pid, instance_id)
                    debugger_server_client = await self._connect_debugger_server()
                    logging.info(f'InstanceId: {instance_id}, port: {self.debugger_server_port}, '
                                 f'debugger server connected')
                    self.debugger_server_port += 1

                    to_send_msg_queue = asyncio.Queue()
                    received_msg_queue = asyncio.Queue()
                    self.to_send_msg_queues[instance_id] = to_send_msg_queue
                    self.received_msg_queues[instance_id] = received_msg_queue
                    taskpool.submit(coroutine=self._sender(debugger_server_client, to_send_msg_queue))
                    taskpool.submit(coroutine=self._receiver(debugger_server_client, received_msg_queue))

                    await self._store_instance(instance_id)
                    num_debugger_server_client += 1

                elif message_type == 'destroyInstance':
                    instance_id = parsed['instanceId']
                    to_send_msg_queue = self.to_send_msg_queues.get(instance_id)
                    if to_send_msg_queue is None:
                        # No debugger connection was opened for this instance (e.g. the
                        # threshold was already reached); a KeyError here would end this task.
                        logging.warning(f'Instance {instance_id} has no debugger server connection to close')
                    else:
                        await self.send_msg_to_debugger_server(instance_id, to_send_msg_queue, 'close')

            except ConnectionClosed:
                logging.info('Connect server connection closed')
                return

    async def _store_instance(self, instance_id):
        """Hand ``instance_id`` to ``get_instance``; waits while the size-1 queue is full."""
        await self.debugger_server_instance.put(instance_id)
        return True