1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_PARTIAL_TRANSFORM_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_PARTIAL_TRANSFORM_H_ 19 20 #include <vector> 21 #include <algorithm> 22 23 #include "frontend/optimizer/irpass.h" 24 #include "mindspore/core/ops/array_ops.h" 25 #include "frontend/optimizer/optimizer.h" 26 #include "frontend/optimizer/anf_visitor.h" 27 #include "frontend/operator/ops.h" 28 #include "frontend/operator/composite/composite.h" 29 30 namespace mindspore { 31 namespace opt { 32 namespace irpass { 33 // {S_Prim_grad, {UpackGraph, Partial{fg, args},}} -> {Partial{{S_Prim_grad, ...}, args}} 34 class GradPartialTransform : public AnfVisitor { 35 public: operator()36 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 37 auto grad_cnode = dyn_cast<CNode>(node); 38 if (grad_cnode == nullptr || grad_cnode->inputs().empty()) { 39 MS_LOG(INTERNAL_EXCEPTION) << "GradPartialTransform encounter invalid node: " << node->DebugString(); 40 } 41 const auto &value = GetCNodeValueWithoutDoSignature(grad_cnode); 42 if (value == nullptr || !value->isa<prim::GradOperation>()) { 43 return nullptr; 44 } 45 auto unpack_graph_node = grad_cnode->input(1); 46 auto prim = GetCNodePrimitive(unpack_graph_node); 47 if (prim == nullptr || !prim->isa<prim::UnpackGraphPrimitive>()) { 48 return nullptr; 49 } 50 auto unpack_graph_cnode = dyn_cast<CNode>(unpack_graph_node); 51 MS_EXCEPTION_IF_NULL(unpack_graph_cnode); 52 auto partial_node = unpack_graph_cnode->input(1); 53 if (!IsPrimitiveCNode(partial_node, prim::kPrimPartial)) { 54 return nullptr; 55 } 56 if (transformed_nodes_.count(node) != 0) { 57 return nullptr; 58 } 59 auto partial_cnode = dyn_cast<CNode>(partial_node); 60 MS_EXCEPTION_IF_NULL(partial_cnode); 61 const auto partial_value_node = NewValueNode(prim::kPrimPartial); 62 AnfNodeWeakPtrList inputs = {partial_value_node, node}; 63 constexpr auto ignored_partial_input_count = 2; 64 (void)std::transform(partial_cnode->weak_inputs().cbegin() + ignored_partial_input_count, 65 partial_cnode->weak_inputs().cend(), std::back_inserter(inputs), 66 [](const AnfNodeWeakPtr &inp) { return inp; }); 67 68 auto new_node = grad_cnode->func_graph()->NewCNodeInOrderWeak(inputs); 69 (void)transformed_nodes_.emplace(node); 70 return new_node; 71 } 72 73 private: 74 mindspore::HashSet<AnfNodePtr> transformed_nodes_; 75 }; 76 } // namespace irpass 77 } // namespace opt 78 } // namespace mindspore 79 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_PARTIAL_TRANSFORM_H_ 80