# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
import time
from mindspore.profiler import Profiler
from .config import config

# Number of sync_data() calls made so far in this process; used to give every
# synchronization its own lock-file name.
_global_sync_count = 0


def get_device_id():
    """Return the integer value of the DEVICE_ID environment variable (0 if unset)."""
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    """Return the integer value of the RANK_SIZE environment variable (1 if unset)."""
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    """Return the integer value of the RANK_ID environment variable (0 if unset)."""
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    """Return the JOB_ID environment variable, or "default" if it is unset or empty."""
    # os.getenv returns None when the variable is missing; `or` covers both the
    # None and the empty-string case.
    return os.getenv('JOB_ID') or "default"


def sync_data(from_path, to_path):
    """
    Copy data from `from_path` to `to_path` with moxing.

    Download data from remote obs to local directory if the first url is remote url
    and the second one is local path. Upload data from local directory to remote obs
    in contrast.

    Only a process whose device id is a multiple of min(device_num, 8) performs the
    copy; it then creates a lock file. Every caller blocks until that lock file exists.

    Args:
        from_path: Source path or url passed to mox.file.copy_parallel.
        to_path: Destination path or url passed to mox.file.copy_parallel.
    """
    # Imported here so that importing this module does not require moxing.
    import moxing as mox

    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            # Best effort: creation failures are ignored, as in the original code.
            pass
        print("===save flag===")

    # Wait until the copying process has published the lock file.
    while not os.path.exists(sync_lock):
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.

    Args:
        pre_process: Optional zero-argument callable run after the downloads and
            before the wrapped function.
        post_process: Optional zero-argument callable run after the wrapped function
            and before the outputs are uploaded.

    Returns:
        A decorator. The decorated function returns whatever the wrapped function
        returns.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            # NOTE(review): block nesting was reconstructed from a whitespace-mangled
            # source; pre_process is assumed to run only when enable_modelarts is set,
            # mirroring post_process below - confirm.
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                config.device_num = get_device_num()
                config.device_id = get_device_id()
                # exist_ok avoids a failure when several processes create the
                # directory at the same time.
                os.makedirs(config.output_path, exist_ok=True)

                if pre_process:
                    pre_process()

            # Keep a handle (or None) so the analyse step below never hits an
            # unbound name.
            profiler = Profiler() if config.enable_profiling else None

            result = run_func(*args, **kwargs)

            if profiler is not None:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)

            # Propagate the wrapped function's result instead of dropping it.
            return result
        return wrapped_func
    return wrapper