• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Moxing adapter for ModelArts"""
17
18import os
19import functools
20from mindspore.profiler import Profiler
21from .config import config
22
# Incremented by every sync_data() call so that each call uses its own
# "/tmp/copy_sync.lock<N>" flag file instead of reusing a previous call's flag.
_global_sync_count = 0
24
def get_device_id():
    """Return the device id read from the ``DEVICE_ID`` environment variable.

    Falls back to ``0`` when the variable is not set.
    """
    raw_value = os.environ.get('DEVICE_ID', '0')
    return int(raw_value)
28
29
def get_device_num():
    """Return the number of devices taken from the ``RANK_SIZE`` environment variable.

    Falls back to ``1`` when the variable is not set.
    """
    return int(os.environ.get('RANK_SIZE', '1'))
33
34
def get_rank_id():
    """Return the global rank id read from the ``RANK_ID`` environment variable.

    Falls back to ``0`` when the variable is not set.
    """
    rank_text = os.environ.get('RANK_ID', '0')
    rank = int(rank_text)
    return rank
38
39
def get_job_id():
    """Return the job id from the ``JOB_ID`` environment variable.

    Returns:
        The value of ``JOB_ID``, or ``"default"`` when the variable is unset
        or set to the empty string.
    """
    # os.getenv returns None when the variable is missing; `or` maps both
    # None and "" to the fallback instead of leaking None to callers.
    return os.getenv('JOB_ID') or "default"
44
def sync_data(from_path, to_path):
    """
    Synchronize data between a remote OBS url and a local directory.

    Downloads from remote OBS to a local directory when ``from_path`` is a
    remote url and ``to_path`` is a local path, and uploads from local to
    remote OBS in the opposite direction; the copy is delegated to
    ``moxing.file.copy_parallel``.

    Only the process whose device id is a multiple of ``min(device_num, 8)``
    performs the copy. It then creates a per-call lock file, and every process
    (the copier included) blocks, polling once per second, until that lock
    file exists.

    Args:
        from_path: Source path or url.
        to_path: Destination path or url.

    Raises:
        OSError: If the lock file cannot be created. Raising here is
            preferable to continuing, since the other processes would
            otherwise poll for the missing flag forever.
    """
    import moxing as mox
    import time
    global _global_sync_count
    # A distinct flag per call, so a later sync never sees an earlier call's flag.
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most; clamp to >= 1 so RANK_SIZE=0
    # does not cause a ZeroDivisionError.
    devices_per_server = max(1, min(get_device_num(), 8))
    if get_device_id() % devices_per_server == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        # open() is portable (os.mknod is Unix-only and may require elevated
        # privileges on some platforms) and also succeeds when the flag
        # already exists, so only genuine failures propagate.
        with open(sync_lock, "w"):
            pass
        print("===save flag===")

    # Wait until the copying process has published the flag.
    while not os.path.exists(sync_lock):
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))
74
75
def moxing_wrapper(pre_process=None, post_process=None):
    """
    Build a decorator that adds ModelArts data sync and optional profiling around a run function.

    When ``config.enable_modelarts`` is true, the decorated function first
    downloads ``config.data_url`` -> ``config.data_path``,
    ``config.checkpoint_url`` -> ``config.load_path`` and
    ``config.train_url`` -> ``config.output_path`` (each only if the url is
    set). It then stores the device count/id on ``config``, creates
    ``config.output_path`` if missing, and calls ``pre_process``.

    It then runs ``run_func``. When ``config.enable_profiling`` is true, a
    MindSpore ``Profiler`` is created before the run and analysed after it.

    Finally, under ModelArts, it calls ``post_process`` and uploads
    ``config.output_path`` back to ``config.train_url`` if that url is set.

    Args:
        pre_process: Optional zero-argument callable invoked before the run
            (ModelArts only).
        post_process: Optional zero-argument callable invoked after the run
            (ModelArts only).

    Returns:
        A decorator. The decorated function returns whatever ``run_func``
        returns.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            # Decide once: the profiler analysed below is exactly the one
            # created here, even if config.enable_profiling changes during the run.
            profiler = Profiler() if config.enable_profiling else None

            # Keep the run function's result so the decorator is transparent.
            result = run_func(*args, **kwargs)

            if profiler is not None:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
            return result
        return wrapped_func
    return wrapper
121