/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_parallel_utils.h"

#include <inttypes.h>
#include <sys/time.h>
#include <algorithm>

#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"

namespace mindspore {
namespace parallel {
// Checks whether `cnode` applies the primitive called `name`.
// Returns false when `cnode` is null or when the value held by input(0) is not a Primitive.
// Throws (MS_EXCEPTION_IF_NULL) when input(0) is not a value node.
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  if (!cnode) {
    return false;
  }
  ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(anf_node);
  // cast<PrimitivePtr>() yields nullptr for non-Primitive values; it used to be dereferenced unchecked.
  PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  return prim->name() == name;
}

// Decides whether the parallel pass should handle `cnode`.
// Returns false for non-primitive calls, black-listed primitives, and Cast nodes without OperatorInfo
// user data; returns true for GetNext / VirtualOutput; otherwise reports the node's in-forward flag.
bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  if (IsInParallelBlackList(prim)) {
    MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
    return false;
  }
  // get_next is not in the forward graph, we need mark the get_next as the forward node
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }

  return cnode->in_forward_flag();
}

// Collects the tensor shapes stored in a ValueList / ValueTuple value node.
// Stops (with a warning) at the first element that is not a tensor, returning the shapes gathered so far.
// Throws if `node` is neither a ValueList nor a ValueTuple.
Shapes GetValueListShape(const AnfNodePtr &node) {
  Shapes shapes;
  std::vector<ValuePtr> inputs_seq;
  if (IsValueNode<ValueList>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
  } else if (IsValueNode<ValueTuple>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  } else {
    MS_LOG(EXCEPTION) << "node is neither ValueList nor ValueTuple";
  }
  for (const auto &ele : inputs_seq) {
    auto tensor = ele->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "The value node is not a tensor";
      break;
    }
    shapes.push_back(tensor->shape());
  }
  return shapes;
}

// Returns the output shape(s) of `node`.
// - ValueList / ValueTuple nodes: delegates to GetValueListShape.
// - MakeRef CNodes: delegates to GetRefKeyNodeShape on input(1).
// - CNodes whose input(0) is itself a CNode: uses the shape of input(1).
// A tuple shape yields one entry per element; otherwise a single entry is returned.
// Throws when no shape is available or an element is not a plain abstract::Shape.
Shapes GetNodeShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shapes;
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return GetValueListShape(node);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (IsValueNode<Primitive>(cnode->input(0))) {
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      if (prim->name() == MAKEREF) {
        AnfNodePtr ref_node = cnode->input(1);
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(ref_node);
        MS_EXCEPTION_IF_NULL(func_graph);
        return GetRefKeyNodeShape(ref_node, func_graph);
      }
    }
    if (cnode->input(0)->isa<CNode>()) {
      if (cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    // Bind by reference: no need to copy the list of element shapes.
    const auto &tuple_shape = tuple_shape_ptr->shape();
    for (const auto &shape : tuple_shape) {
      auto each_shape = dyn_cast<abstract::Shape>(shape);
      MS_EXCEPTION_IF_NULL(each_shape);
      shapes.push_back(each_shape->shape());
    }
  } else {
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    shapes.push_back(shape_ptr->shape());
  }
  return shapes;
}

// Builds the instance name for the `index`-th generated op of `node`.
// The name is HashInstanceName(<node full scope name> + "_" + index).
// Throws if `node` does not call a primitive.
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsValueNode<Primitive>(node->input(0))) {
    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
  }
  std::string name_base = node->fullname_with_scope();
  std::string name = name_base + "_" + std::to_string(index);
  return HashInstanceName(name);
}

// Handles the primitive in new_node_input[0] when it carries a string GROUP attribute.
// Resolves that hash name to a rank-list name through g_device_manager and stores the result as the
// primitive's GROUP_RANKS attribute. Does nothing for an empty input list.
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
  if (new_node_input.empty()) {
    return;
  }

  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);

  // Look up by reference instead of copying the whole attribute map. `iter` is not used after
  // AddAttr below, which may modify the map.
  const auto &attrs = prim->attrs();
  auto iter = attrs.find(GROUP);
  if (iter != attrs.end()) {
    auto value = iter->second;
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<StringImm>()) {
      std::string hash_name = value->cast<StringImmPtr>()->value();
      MS_EXCEPTION_IF_NULL(g_device_manager);
      std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
      (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
    }
  }
}

// Builds the input list for a node that replaces `node` with `replace_op`.
// Element 0 is a new value node holding the op instance created from replace_op's attributes.
// The following elements are the original inputs (input(1), plus input(2) for EmbeddingLookup). Then:
// - if replace_op has params, each param value is inserted at its recorded position; when the first
//   param's position is 1, the last carried-over input is dropped first;
// - otherwise, for SyncBatchNorm, all remaining original inputs (from index 2) are appended.
// A node with only a primitive input (e.g. GetNext) yields just the new op.
// Throws on op-creation failure, missing inputs, or an out-of-range param position.
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                       const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const OperatorArgs &arg_replace_op = replace_op.second;
  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
  if (pyop_instance == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
  }
  const OperatorParams &params = arg_replace_op.second;
  if (node->inputs().size() < 2) {
    // GetNext operator does not have inputs
    if (node->inputs().size() == 1) {
      return {NewValueNode(pyop_instance)};
    }
    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
  }
  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};

  if (replace_op.first == EMBEDDING_LOOKUP) {
    // EmbeddingLookup carries a second data input; make sure it exists before reading it.
    if (node->inputs().size() < 3) {
      MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 3";
    }
    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
  }

  if (!params.empty()) {
    const int64_t first_position = params.front().second;
    if (first_position == 1) {
      replace_input.pop_back();
    }
    for (const auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      if (val == nullptr) {
        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
      }
      const int64_t position = param.second;
      // vector::insert requires position in [0, size()]; anything else is undefined behavior.
      if (position < 0 || static_cast<size_t>(position) > replace_input.size()) {
        MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " param position " << position
                          << " is out of range [0, " << replace_input.size() << "]";
      }
      (void)replace_input.insert(replace_input.begin() + position, val);
    }
  } else if (replace_op.first == SYNC_BATCH_NORM) {
    for (size_t i = 2; i < node->inputs().size(); ++i) {
      replace_input.push_back(node->input(i));
    }
  }
  SetCommunicationOpGroupLabel(replace_input);
  return replace_input;
}
}  // namespace parallel
}  // namespace mindspore