• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/parallel/silent_check/silent_check.h"
17 #include "ir/graph_utils.h"
18 #include "ir/func_graph.h"
19 #include "ops/other_ops.h"
20 
21 namespace mindspore {
22 namespace parallel {
GetLossScale()23 void SilentCheck::GetLossScale() {
24   MS_EXCEPTION_IF_NULL(root_);
25   auto parameters = root_->parameters();
26   for (const auto &param : parameters) {
27     auto param_ptr = param->cast<ParameterPtr>();
28     MS_EXCEPTION_IF_NULL(param_ptr);
29     const auto &name = param_ptr->name();
30     if (name == kScale_Sense) {
31       loss_scale_ = param;
32     }
33   }
34 }
35 
ModifySilentCheckOps()36 void SilentCheck::ModifySilentCheckOps() {
37   MS_EXCEPTION_IF_NULL(root_);
38   auto ret = root_->get_return();
39   MS_EXCEPTION_IF_NULL(ret);
40   MS_EXCEPTION_IF_NULL(mng_);
41   const auto &all_nodes = DeepScopedGraphSearch(ret);
42   for (const auto &node : all_nodes) {
43     if (node && !IsPrimitiveCNode(node, prim::kPrimMirrorSilentCheck)) {
44       continue;
45     }
46     auto cnode = node->cast<CNodePtr>();
47     if (loss_scale_ != nullptr) {
48       mng_->SetEdge(cnode, LOSS_SCALE_INDEX, loss_scale_);
49     }
50   }
51 }
52 }  // namespace parallel
53 }  // namespace mindspore
54