1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <unordered_map>
18
19 #include "frontend/optimizer/ad/kpynative.h"
20 #include "common/common_test.h"
21 #include "common/py_func_graph_fetcher.h"
22 #include "ir/manager.h"
23 #include "ir/value.h"
24 #include "ir/func_graph_cloner.h"
25 #include "utils/log_adapter.h"
26 #include "ir/graph_utils.h"
27 #include "pipeline/jit/resource.h"
28 #include "pipeline/jit/parse/parse.h"
29 #include "debug/anf_ir_utils.h"
30 #include "frontend/operator/ops.h"
31
32 namespace mindspore {
33 namespace ad {
34 class TestKPynative : public UT::Common {
35 public:
36 pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
37
38 protected:
BuildArg()39 AbstractBasePtr BuildArg() {
40 std::vector<int64_t> shp = {2, 2};
41 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
42 auto abstract = tensor->ToAbstract();
43 return abstract;
44 }
45
BuildPrimalFuncGraph(const std::string & testCase)46 FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
47 auto g = std::make_shared<FuncGraph>();
48 auto x = g->add_parameter();
49 auto y = g->add_parameter();
50 x->set_abstract(BuildArg());
51 y->set_abstract(BuildArg());
52 auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
53 c_node->set_abstract(BuildArg());
54 g->set_output(c_node);
55 return g;
56 }
57
58 // a = x * y
59 // b = stop_gradient(a)
60 // c = b * y
61 // return c
BuildStopGradient(const std::string & testCase)62 FuncGraphPtr BuildStopGradient(const std::string &testCase) {
63 auto g = std::make_shared<FuncGraph>();
64 auto x = g->add_parameter();
65 x->debug_info()->set_name("x");
66 auto y = g->add_parameter();
67 y->debug_info()->set_name("y");
68 x->set_abstract(BuildArg());
69 y->set_abstract(BuildArg());
70 auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
71 a_node->set_abstract(BuildArg());
72 auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
73 b_node->set_abstract(BuildArg());
74 auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
75 c_node->set_abstract(BuildArg());
76 auto d_node =
77 g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
78 d_node->set_abstract(BuildArg());
79 g->set_output(d_node);
80 return g;
81 }
82
BuildBpropFuncGraph(const FuncGraphPtr & primal_fg)83 FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
84 auto input_params = primal_fg->parameters();
85 std::vector<ValuePtr> input_param_values;
86 std::for_each(input_params.begin(), input_params.end(),
87 [&](const AnfNodePtr ¶m) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
88 auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
89 auto node_list = TopoSort(primal_fg->output());
90 for (auto node : node_list) {
91 if (node->isa<CNode>()) {
92 auto c_node = node->cast<CNodePtr>();
93 auto out = c_node->abstract()->GetValueTrack();
94 ValuePtrList args;
95 for (size_t i = 1; i < c_node->inputs().size(); ++i) {
96 args.push_back(c_node->input(i)->abstract()->GetValueTrack());
97 }
98 GradPynativeOp(k_pynative_cell, c_node, args, out);
99 }
100 }
101 auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false, false, true);
102 return bprop_fg;
103 }
104 };
105
// Differentiates `x * y` through the PyNative grad API and checks that a bprop graph is produced.
TEST_F(TestKPynative, test_simple_add) {
  auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
  resource->manager()->KeepRoots({primal_fg});

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  // Without this check the test would pass even if no bprop graph were built.
  ASSERT_TRUE(bprop_fg != nullptr);
  resource->manager()->KeepRoots({bprop_fg});
}
113
// Differentiates a graph that contains a stop_gradient node and checks that a bprop graph is produced.
TEST_F(TestKPynative, test_stop_gradient) {
  auto primal_fg = BuildStopGradient("test_stop_gradient");
  resource->manager()->KeepRoots({primal_fg});

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  // Without this check the test would pass even if no bprop graph were built.
  ASSERT_TRUE(bprop_fg != nullptr);
  resource->manager()->KeepRoots({bprop_fg});
}
121 } // namespace ad
122 } // namespace mindspore
123