• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_PARTIAL_TRANSFORM_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_PARTIAL_TRANSFORM_H_
19 
20 #include <vector>
21 #include <algorithm>
22 
23 #include "frontend/optimizer/irpass.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "frontend/optimizer/optimizer.h"
26 #include "frontend/optimizer/anf_visitor.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/operator/composite/composite.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace irpass {
33 // {S_Prim_grad, {UpackGraph, Partial{fg, args},}} -> {Partial{{S_Prim_grad, ...}, args}}
34 class GradPartialTransform : public AnfVisitor {
35  public:
operator()36   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
37     auto grad_cnode = dyn_cast<CNode>(node);
38     if (grad_cnode == nullptr || grad_cnode->inputs().empty()) {
39       MS_LOG(INTERNAL_EXCEPTION) << "GradPartialTransform encounter invalid node: " << node->DebugString();
40     }
41     const auto &value = GetCNodeValueWithoutDoSignature(grad_cnode);
42     if (value == nullptr || !value->isa<prim::GradOperation>()) {
43       return nullptr;
44     }
45     auto unpack_graph_node = grad_cnode->input(1);
46     auto prim = GetCNodePrimitive(unpack_graph_node);
47     if (prim == nullptr || !prim->isa<prim::UnpackGraphPrimitive>()) {
48       return nullptr;
49     }
50     auto unpack_graph_cnode = dyn_cast<CNode>(unpack_graph_node);
51     MS_EXCEPTION_IF_NULL(unpack_graph_cnode);
52     auto partial_node = unpack_graph_cnode->input(1);
53     if (!IsPrimitiveCNode(partial_node, prim::kPrimPartial)) {
54       return nullptr;
55     }
56     if (transformed_nodes_.count(node) != 0) {
57       return nullptr;
58     }
59     auto partial_cnode = dyn_cast<CNode>(partial_node);
60     MS_EXCEPTION_IF_NULL(partial_cnode);
61     const auto partial_value_node = NewValueNode(prim::kPrimPartial);
62     AnfNodeWeakPtrList inputs = {partial_value_node, node};
63     constexpr auto ignored_partial_input_count = 2;
64     (void)std::transform(partial_cnode->weak_inputs().cbegin() + ignored_partial_input_count,
65                          partial_cnode->weak_inputs().cend(), std::back_inserter(inputs),
66                          [](const AnfNodeWeakPtr &inp) { return inp; });
67 
68     auto new_node = grad_cnode->func_graph()->NewCNodeInOrderWeak(inputs);
69     (void)transformed_nodes_.emplace(node);
70     return new_node;
71   }
72 
73  private:
74   mindspore::HashSet<AnfNodePtr> transformed_nodes_;
75 };
76 }  // namespace irpass
77 }  // namespace opt
78 }  // namespace mindspore
79 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_PARTIAL_TRANSFORM_H_
80