• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/node_check.h"
18 
19 #include <set>
20 #include <string>
21 
22 #include "frontend/parallel/ops_info/ops_utils.h"
23 #include "mindspore/core/ops/other_ops.h"
24 
25 namespace mindspore {
26 namespace parallel {
27 const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {STACK, TENSOR_SCATTER_UPDATE, MESHGRID};
28 
IsInBatchParallelBlackList(const PrimitivePtr & prim)29 bool IsInBatchParallelBlackList(const PrimitivePtr &prim) {
30   MS_EXCEPTION_IF_NULL(prim);
31   return (BATCH_PARALLEL_BLACK_LIST.find(prim->name()) != BATCH_PARALLEL_BLACK_LIST.end());
32 }
33 
IsFromParallelOptimizerRs(const AnfNodePtr & node)34 bool IsFromParallelOptimizerRs(const AnfNodePtr &node) {
35   if (!IsPrimitiveCNode(node, prim::kPrimReduceScatter)) {
36     return false;
37   }
38   auto prim = GetCNodePrimitive(node->cast<CNodePtr>());
39   if (prim->instance_name().find("grad_parallel_optimizer") == std::string::npos) {
40     return false;
41   }
42   return true;
43 }
44 
IsFromGradMirrorAR(const AnfNodePtr & node)45 bool IsFromGradMirrorAR(const AnfNodePtr &node) {
46   if (!IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
47     return false;
48   }
49   auto prim = GetCNodePrimitive(node->cast<CNodePtr>());
50   if (prim->instance_name().find("grad_mirror") == std::string::npos) {
51     return false;
52   }
53   return true;
54 }
55 }  // namespace parallel
56 }  // namespace mindspore
57