• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
Copyright (c) 2024 Huawei Device Co., Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Description: WebSocket communication layer between the test driver, the
connect server and the per-instance debugger servers.
18"""
19
20import asyncio
21import json
22
23from install_lib import install
24from fport import Fport
25
26install('websockets')
27import websockets.protocol
28from websockets import connect, ConnectionClosed
29
30
class ToolchainWebSocket(object):
    """WebSocket bridge between a test procedure and the toolchain servers.

    One websocket connection is opened to the *connect server*. Each
    ``addInstance`` message it sends opens an extra connection to a
    *debugger server* for that instance, with at most
    ``debugger_server_connection_threshold`` such connections open at a time.
    Every connection is serviced by coroutines submitted to a task pool. The
    test procedure talks to them only through ``asyncio.Queue`` objects, using
    the public ``send_*`` / ``recv_*`` / ``get_instance`` helpers.
    """

    def __init__(self, driver, connect_server_port, debugger_server_port, print_protocol=True):
        """Store connection settings; no I/O happens here.

        :param driver: test driver. ``driver._deivce.host`` is used as the server IP and
            ``driver.log_info`` as the protocol logger.
        :param connect_server_port: port of the connect server websocket.
        :param debugger_server_port: initial port handed to ``Fport`` for the first
            debugger server. It is advanced after every successful connection.
        :param print_protocol: when False, protocol messages are not logged.
        """
        self.driver = driver
        # NOTE(review): the spelling '_deivce' must match the driver object's attribute;
        # it is kept as-is here. Verify against the driver class before renaming.
        self.server_ip = driver._deivce.host
        self.connect_server_port = connect_server_port
        self.debugger_server_port = debugger_server_port
        # maximum number of simultaneously open debugger server connections
        self.debugger_server_connection_threshold = 3

        # created in main_task, i.e. inside the running event loop
        self.to_send_msg_queue_for_connect_server = None
        self.received_msg_queue_for_connect_server = None

        self.to_send_msg_queues = {}  # key: instance_id, value: to_send_msg_queue
        self.received_msg_queues = {}  # key: instance_id, value: received_msg_queue
        # single-slot queue handing newly connected instance ids to the procedure
        self.debugger_server_instance = None
        # set to False by no_more_instance(); presumably polled by the procedure
        self.new_instance_flag = None
        self.log = self.driver.log_info if print_protocol else (lambda s: None)

    async def recv_msg_of_debugger_server(self, instance_id, queue):
        """Wait for, log and return the next message from debugger server *instance_id*."""
        message = await queue.get()
        queue.task_done()
        self.log(f'[<==] Instance {instance_id} receive message: {message}')
        return message

    async def send_msg_to_debugger_server(self, instance_id, queue, message):
        """Queue *message* for debugger server *instance_id*; always returns True.

        The string ``'close'`` is a sentinel that makes the sender close the connection.
        """
        await queue.put(message)
        self.log(f'[==>] Instance {instance_id} send message: {message}')
        return True

    async def get_instance(self):
        """Wait until a new debugger server instance is connected and return its id."""
        instance_id = await self.debugger_server_instance.get()
        self.debugger_server_instance.task_done()
        return instance_id

    def no_more_instance(self):
        """Mark that no further instances are expected."""
        self.new_instance_flag = False

    async def recv_msg_of_connect_server(self):
        """Wait for and return the next raw message received from the connect server."""
        message = await self.received_msg_queue_for_connect_server.get()
        self.received_msg_queue_for_connect_server.task_done()
        return message

    async def send_msg_to_connect_server(self, message):
        """Queue *message* for the connect server; always returns True."""
        await self.to_send_msg_queue_for_connect_server.put(message)
        self.log(f'[==>] Connect server send message: {message}')
        return True

    async def main_task(self, taskpool, procedure, pid):
        """Connect to the connect server and schedule all background coroutines.

        :param taskpool: object whose ``submit`` schedules a coroutine
            (assumed to run it on the current event loop -- confirm with the task pool).
        :param procedure: coroutine function called with this object; the test body.
        :param pid: forwarded to ``Fport.fport_debugger_server`` for every new instance.
        """
        # the async queue must be initialized in task
        self.to_send_msg_queue_for_connect_server = asyncio.Queue()
        self.received_msg_queue_for_connect_server = asyncio.Queue()
        self.debugger_server_instance = asyncio.Queue(maxsize=1)

        connect_server_client = await self._connect_connect_server()
        taskpool.submit(self._sender(connect_server_client, self.to_send_msg_queue_for_connect_server))
        taskpool.submit(self._receiver_of_connect_server(connect_server_client,
                                                         self.received_msg_queue_for_connect_server,
                                                         taskpool, pid))
        taskpool.submit(procedure(self))

    async def _sender(self, client, send_queue):
        """Send every queued message to *client* as JSON text.

        The sentinel ``'close'`` closes the connection and ends the coroutine.
        """
        assert client.state == websockets.protocol.OPEN, f'Client state of _sender is: {client.state}'
        while True:
            send_message = await send_queue.get()
            send_queue.task_done()
            if send_message == 'close':
                await client.close(reason='close')
                return
            await client.send(json.dumps(send_message))

    async def _receiver(self, client, received_queue):
        """Push every raw message from *client* into *received_queue* until it closes."""
        assert client.state == websockets.protocol.OPEN, f'Client state of _receiver is: {client.state}'
        while True:
            try:
                response = await client.recv()
                await received_queue.put(response)
            except ConnectionClosed:
                self.log('Debugger server connection closed')
                return

    def _connect_connect_server(self):
        """Return the (not yet awaited) websocket connection to the connect server."""
        client = connect(f'ws://{self.server_ip}:{self.connect_server_port}',
                         open_timeout=10,
                         ping_interval=None)
        return client

    def _connect_debugger_server(self):
        """Return the (not yet awaited) websocket connection to the current debugger port."""
        client = connect(f'ws://{self.server_ip}:{self.debugger_server_port}',
                         open_timeout=6,
                         ping_interval=None)
        return client

    async def _receiver_of_connect_server(self, client, receive_queue, taskpool, pid):
        """Forward connect server messages to *receive_queue* and react to instance events.

        Every raw message is queued for the procedure first. Messages that decode to a
        JSON object are then handed to ``_handle_connect_server_message``. Returns when
        the connect server closes the connection.
        """
        assert client.state == websockets.protocol.OPEN, \
            f'Client state of _receiver_of_connect_server is: {client.state}'
        # ids of instances that currently own an open debugger server connection
        active_instances = set()
        while True:
            try:
                response = await client.recv()
                await receive_queue.put(response)
                self.log(f'[<==] Connect server receive message: {response}')
                try:
                    message = json.loads(response)
                except ValueError:
                    # not JSON (JSONDecodeError / UnicodeDecodeError): the raw frame is
                    # already queued for the procedure, so just keep receiving
                    continue
                await self._handle_connect_server_message(message, taskpool, pid, active_instances)
            except ConnectionClosed:
                self.log('Connect server connection closed')
                return

    async def _handle_connect_server_message(self, message, taskpool, pid, active_instances):
        """Apply one decoded connect server message; unrelated messages are ignored.

        The debugger server client is only responsible for adding and removing instances.

        :param message: decoded JSON value; anything other than a dict is ignored.
        :param active_instances: set of instance ids with an open debugger server
            connection; updated in place.
        """
        if not isinstance(message, dict):
            return
        msg_type = message.get('type')
        if msg_type == 'addInstance':
            if len(active_instances) >= self.debugger_server_connection_threshold:
                self.log(f'Debugger server connection threshold reached, '
                         f'ignore addInstance: {message.get("instanceId")}')
                return
            instance_id = message['instanceId']
            await self._add_debugger_server_instance(instance_id, taskpool, pid)
            active_instances.add(instance_id)
        elif msg_type == 'destroyInstance':
            instance_id = message['instanceId']
            if instance_id not in active_instances:
                # never connected (e.g. skipped by the threshold) or already destroyed:
                # there is no sender to close, and the counter must not drift
                self.log(f'Instance {instance_id} has no debugger server connection, '
                         f'ignore destroyInstance')
                return
            active_instances.discard(instance_id)
            to_send_msg_queue = self.to_send_msg_queues[instance_id]
            await self.send_msg_to_debugger_server(instance_id, to_send_msg_queue, 'close')

    async def _add_debugger_server_instance(self, instance_id, taskpool, pid):
        """Open a debugger server connection for *instance_id* and publish the instance.

        Forwards a port through ``Fport`` and connects to it. It then creates the
        instance's send/receive queues and schedules their coroutines. Finally it hands
        the id to the procedure via ``debugger_server_instance``, waiting while that
        single-slot queue is full.
        """
        port = Fport.fport_debugger_server(self.debugger_server_port, pid, instance_id)
        assert port > 0, 'Failed to fport debugger server for 3 times, the port is very likely occupied'
        self.debugger_server_port = port
        debugger_server_client = await self._connect_debugger_server()
        self.log(f'InstanceId: {instance_id}, port: {self.debugger_server_port}, '
                 f'debugger server connected')
        # the next addInstance hands the following port to Fport
        self.debugger_server_port += 1

        to_send_msg_queue = asyncio.Queue()
        received_msg_queue = asyncio.Queue()
        self.to_send_msg_queues[instance_id] = to_send_msg_queue
        self.received_msg_queues[instance_id] = received_msg_queue
        taskpool.submit(coroutine=self._sender(debugger_server_client, to_send_msg_queue))
        taskpool.submit(coroutine=self._receiver(debugger_server_client, received_msg_queue))

        await self._store_instance(instance_id)

    async def _store_instance(self, instance_id):
        """Hand *instance_id* to ``get_instance``; waits while the slot is occupied."""
        await self.debugger_server_instance.put(instance_id)
        return True
169