• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""api definition"""
16import threading
17from mindspore.parallel._auto_parallel_context import auto_parallel_context
18
19
20class Hccl():
21    """Hccl definition"""
22    _instance_lock = threading.Lock()
23    _instance = None
24    _rank_id = 0
25    _rank_size = 1
26
27    def __init__(self):
28        pass
29
30    # pylint: disable=unused-argument
31    def __new__(cls, *args, **kwargs):
32        if not hasattr(Hccl, "_instance") or Hccl._instance is None:
33            with Hccl._instance_lock:
34                if not hasattr(Hccl,
35                               "_instance") or Hccl._instance is None:
36                    Hccl._instance = object.__new__(cls)
37                    Hccl._instance.__init__()
38        return Hccl._instance
39
40    @property
41    def rank_id(self):
42        return self._rank_id
43
44    @rank_id.setter
45    def rank_id(self, rank_id):
46        self._rank_id = rank_id
47
48    @property
49    def rank_size(self):
50        return self._rank_size
51
52    @rank_size.setter
53    def rank_size(self, size):
54        self._rank_size = size
55
56
57# pylint: disable=unused-argument
def get_rank_id(group=None):
    """Return the rank id held by the shared ``Hccl`` object; ``group`` is ignored."""
    return Hccl().rank_id
61
62
def get_rank_size(group=None):
    """Return the rank size for ``group``.

    Args:
        group: ``None``, or a group name string. ``None`` and any name
            containing ``"nccl_world_group"`` select the world size taken from
            the auto-parallel context. Any other string must start with the
            size followed by ``-`` (for example ``"8-abc"`` gives 8).

    Returns:
        int: 1 when the world size is requested but no device number has been
        set in the auto-parallel context, otherwise the configured device
        number, or the integer parsed from the group name.

    Raises:
        ValueError: if ``group`` is neither ``None`` nor a string, or if the
            leading part of a group name is not an integer.
    """
    # Reject non-string groups up front so the substring test below can never
    # raise TypeError on a non-iterable argument.
    if group is not None and not isinstance(group, str):
        raise ValueError(
            "group must be None or a str, but got {}".format(type(group).__name__))

    if group is None or "nccl_world_group" in group:
        context = auto_parallel_context()
        if context.get_device_num_is_set() is False:
            return 1
        return context.get_device_num()

    # Group names encode their size before the first '-'.
    return int(group.split("-")[0])
72
73
74# pylint: disable=unused-argument
def get_world_rank_from_group_rank(group, group_rank_id):
    """Return ``group_rank_id`` unchanged; ``group`` is ignored by this stub."""
    return group_rank_id
77
78
79# pylint: disable=unused-argument
def get_group_rank_from_world_rank(world_rank_id, group):
    """Return ``world_rank_id`` unchanged; ``group`` is ignored by this stub."""
    return world_rank_id
82
83
84# pylint: disable=unused-argument
def create_group(group, rank_size, rank_ids):
    """Accept a group description and do nothing (stub); returns ``None``."""
87
88
89# pylint: disable=unused-argument
def destroy_group(group):
    """Accept a group name and do nothing (stub); returns ``None``."""
92
93
94# pylint: disable=unused-argument
def set_fusion_strategy_by_idx():
    """Do nothing (stub); takes no arguments and returns ``None``."""
97
98
99# pylint: disable=unused-argument
def set_fusion_strategy_by_size():
    """Do nothing (stub); takes no arguments and returns ``None``."""
102