1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <unordered_map> 18 19 #include "frontend/optimizer/ad/kpynative.h" 20 #include "common/common_test.h" 21 #include "common/py_func_graph_fetcher.h" 22 #include "ir/manager.h" 23 #include "ir/value.h" 24 #include "ir/func_graph_cloner.h" 25 #include "utils/log_adapter.h" 26 #include "ir/graph_utils.h" 27 #include "pipeline/jit/resource.h" 28 #include "pipeline/jit/parse/parse.h" 29 #include "debug/anf_ir_utils.h" 30 #include "frontend/operator/ops.h" 31 32 namespace mindspore { 33 namespace ad { 34 class TestKPynative : public UT::Common { 35 public: 36 pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>(); 37 38 protected: 39 AbstractBasePtr BuildArg() { 40 std::vector<int64_t> shp = {2, 2}; 41 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 42 auto abstract = tensor->ToAbstract(); 43 return abstract; 44 } 45 46 FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) { 47 auto g = std::make_shared<FuncGraph>(); 48 auto x = g->add_parameter(); 49 auto y = g->add_parameter(); 50 x->set_abstract(BuildArg()); 51 y->set_abstract(BuildArg()); 52 auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y}); 53 c_node->set_abstract(BuildArg()); 54 g->set_output(c_node); 55 return g; 56 } 57 58 // a = x * y 59 // b = stop_gradient(a) 60 // c = b * y 61 // return c 62 FuncGraphPtr BuildStopGradient(const std::string &testCase) { 63 auto g = std::make_shared<FuncGraph>(); 64 auto x = g->add_parameter(); 65 x->debug_info()->set_name("x"); 66 auto y = g->add_parameter(); 67 y->debug_info()->set_name("y"); 68 x->set_abstract(BuildArg()); 69 y->set_abstract(BuildArg()); 70 auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y}); 71 a_node->set_abstract(BuildArg()); 72 auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node}); 73 b_node->set_abstract(BuildArg()); 74 auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y}); 75 c_node->set_abstract(BuildArg()); 76 auto d_node = 77 g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node}); 78 d_node->set_abstract(BuildArg()); 79 g->set_output(d_node); 80 return g; 81 } 82 83 FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) { 84 auto input_params = primal_fg->parameters(); 85 std::vector<ValuePtr> input_param_values; 86 std::for_each(input_params.begin(), input_params.end(), 87 [&](const AnfNodePtr ¶m) { input_param_values.emplace_back(param->abstract()->BuildValue()); }); 88 auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values); 89 auto node_list = TopoSort(primal_fg->output()); 90 for (auto node : node_list) { 91 if (node->isa<CNode>()) { 92 auto c_node = node->cast<CNodePtr>(); 93 auto out = c_node->abstract()->GetValueTrack(); 94 ValuePtrList args; 95 for (size_t i = 1; i < c_node->inputs().size(); ++i) { 96 args.push_back(c_node->input(i)->abstract()->GetValueTrack()); 97 } 98 GradPynativeOp(k_pynative_cell, c_node, args, out); 99 } 100 } 101 auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false, false, true); 102 return bprop_fg; 103 } 104 }; 105 106 TEST_F(TestKPynative, test_simple_add) { 107 auto primal_fg = BuildPrimalFuncGraph("test_simple_add"); 108 resource->manager()->KeepRoots({primal_fg}); 109 110 auto bprop_fg = BuildBpropFuncGraph(primal_fg); 111 resource->manager()->KeepRoots({bprop_fg}); 112 } 113 114 TEST_F(TestKPynative, test_stop_gradient) { 115 auto primal_fg = BuildStopGradient("test_stop_gradient"); 116 resource->manager()->KeepRoots({primal_fg}); 117 118 auto bprop_fg = BuildBpropFuncGraph(primal_fg); 119 resource->manager()->KeepRoots({bprop_fg}); 120 } 121 } // namespace ad 122 } // namespace mindspore 123