# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""kernel build server"""
import os
from mindspore import log as logger
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process


class Messager:
    """
    Line-oriented message channel over a pair of file descriptors.

    Incoming messages are read one line at a time from `fdin`; outgoing
    results are written to `fdout` as a single tagged line. Subclasses supply
    the actual protocol by overriding `handle` and `exit`.

    Args:
        fdin: File descriptor to read requests from.
        fdout: File descriptor to write responses to.
    """

    def __init__(self, fdin, fdout):
        self.fdin = fdin
        self.fdout = fdout
        # The file objects take ownership of the descriptors (closefd=True).
        self.fin = os.fdopen(fdin, "r")
        self.fout = os.fdopen(fdout, "w")
        self.message = ''

    def __del__(self):
        # Close through the file objects, which own the descriptors. Calling
        # os.close() on the raw descriptors as well would close them twice:
        # buffered output could no longer be flushed and the file objects'
        # own finalizers would hit an already-closed (or reused) descriptor.
        for stream_name, fd_name in (("fin", "fdin"), ("fout", "fdout")):
            stream = getattr(self, stream_name, None)
            try:
                if stream is not None:
                    stream.close()
                else:
                    # __init__ failed before wrapping this descriptor.
                    fd = getattr(self, fd_name, None)
                    if fd is not None:
                        os.close(fd)
            except OSError:
                # Best effort: the peer may be gone or the descriptor invalid,
                # and nothing useful can be done about it in a finalizer.
                pass

    def get_message(self):
        """
        Get message from remote

        Reads one line from the input stream and stores it, without its
        trailing newline, in `self.message`. `self.exit()` is called when
        nothing could be read or when reading is interrupted. An empty
        message or 'FINISH' is acknowledged and then also triggers
        `self.exit()`.

        Returns:
            message
        """
        try:
            # Not read by input() anymore
            res = self.fin.readline()
            if not res:
                logger.debug("[TRACE] read nothing...")
                self.exit()
            # Drop exactly one trailing newline; safe on an empty string.
            if res.endswith('\n'):
                res = res[:-1]
            self.message = res
            logger.debug(f"[IN] {self.message}")
        except (EOFError, KeyboardInterrupt):
            self.exit()
        if self.message in ('', 'FINISH'):
            self.send_ack()
            self.exit()
        return self.message

    def send_res(self, res, keep_format=True):
        """
        Send result to remote

        The result is converted with `str()`, prefixed with the tag '[~]'
        and written as one line. `self.exit()` is called if the pipe is
        broken.

        Args:
            res: Result to send.
            keep_format: True or False. When True, newline, carriage return
                and space are replaced by '[LF]', '[CR]' and '[SP]'; when
                False they are removed.
        """
        logger.debug(f"[OUT] {str(res)}")
        if keep_format:
            res_str = str(res).replace('\n', '[LF]').replace('\r', '[CR]').replace(' ', '[SP]')
        else:
            res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '')
        tag = '[~]'  # The same as client kTAG

        # Not write by print(tag + res_str, flush=True) any more
        try:
            self.fout.write(tag + res_str + "\n")
            self.fout.flush()
        except BrokenPipeError as err:
            logger.info(f"[TRACE] Write {str(err)}")
            self.exit()

    def send_ack(self, success=True):
        """
        Send ack to remote

        Args:
            success: True or False. 'ACK' is sent when True, 'ERR' otherwise.
        """
        if success:
            self.send_res('ACK')
        else:
            self.send_res('ERR')

    def loop(self):
        """
        Messaging loop

        Calls `self.handle()` repeatedly; the loop itself never returns.
        """
        while True:
            self.handle()

    def run(self):
        """Start the messaging loop."""
        self.loop()

    def handle(self):
        """
        An interface communicates with remote.

        Note:
            All subclasses should override this interface.
        """
        raise NotImplementedError

    def exit(self):
        """
        An interface handles the procedure before exit.

        Note:
            All subclasses should override this interface.
        """
        raise NotImplementedError


class AkgBuilder():
    """
    Akg building wrapper

    Args:
        platform: Platform value forwarded to `create_akg_parallel_process`.
    """

    def __init__(self, platform):
        self.platform = platform
        self.attrs = None
        # Set by create(); None until an 'AKG/START' request has been handled.
        self.akg_processor = None

    def create(self, process_num, waitime):
        """ Create akg processor"""
        self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform)

    def accept_json(self, json):
        """ Accept json"""
        return self.akg_processor.accept_json(json)

    def compile(self):
        """Compile"""
        return self.akg_processor.compile(self.attrs)

    def handle(self, messager, arg):
        """
        Handle message about akg

        Args:
            messager: Object providing `get_message`, `send_ack` and `send_res`.
            arg: Request type: 'AKG/START', 'AKG/ATTR' or 'AKG/DATA'.

        Raises:
            RuntimeError: If `arg` is none of the request types above.
        """
        if arg == 'AKG/START':
            # Two follow-up messages: process count, then wait time.
            messager.send_ack()
            process_num_str = messager.get_message()
            messager.send_ack()
            wait_time_str = messager.get_message()
            messager.send_ack()
            self.create(int(process_num_str), int(wait_time_str))
        elif arg == 'AKG/ATTR':
            messager.send_ack()
            self.attrs = messager.get_message()
            messager.send_ack()
        elif arg == 'AKG/DATA':
            messager.send_ack()
            while True:
                req = messager.get_message()
                if req.startswith('{'):
                    # A message starting with '{' is passed on as json.
                    self.accept_json(req)
                    messager.send_ack()
                elif req == 'AKG/WAIT':
                    res = self.compile()
                    messager.send_res(res)
                    break
                else:
                    messager.send_ack(False)
                    break
        else:
            raise RuntimeError(f"Unknown message type: {arg}")


def get_logger():
    """Return the module logger."""
    return logger