• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <unordered_map>
18 
19 #include "frontend/optimizer/ad/kpynative.h"
20 #include "common/common_test.h"
21 #include "common/py_func_graph_fetcher.h"
22 #include "ir/manager.h"
23 #include "ir/value.h"
24 #include "ir/func_graph_cloner.h"
25 #include "utils/log_adapter.h"
26 #include "ir/graph_utils.h"
27 #include "pipeline/jit/resource.h"
28 #include "pipeline/jit/parse/parse.h"
29 #include "debug/anf_ir_utils.h"
30 #include "frontend/operator/ops.h"
31 
32 namespace mindspore {
33 namespace ad {
34 class TestKPynative : public UT::Common {
35  public:
36   pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
37 
38  protected:
BuildArg()39   AbstractBasePtr BuildArg() {
40     std::vector<int64_t> shp = {2, 2};
41     tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
42     auto abstract = tensor->ToAbstract();
43     return abstract;
44   }
45 
BuildPrimalFuncGraph(const std::string & testCase)46   FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
47     auto g = std::make_shared<FuncGraph>();
48     auto x = g->add_parameter();
49     auto y = g->add_parameter();
50     x->set_abstract(BuildArg());
51     y->set_abstract(BuildArg());
52     auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
53     c_node->set_abstract(BuildArg());
54     g->set_output(c_node);
55     return g;
56   }
57 
58   // a = x * y
59   // b = stop_gradient(a)
60   // c = b * y
61   // return c
BuildStopGradient(const std::string & testCase)62   FuncGraphPtr BuildStopGradient(const std::string &testCase) {
63     auto g = std::make_shared<FuncGraph>();
64     auto x = g->add_parameter();
65     x->debug_info()->set_name("x");
66     auto y = g->add_parameter();
67     y->debug_info()->set_name("y");
68     x->set_abstract(BuildArg());
69     y->set_abstract(BuildArg());
70     auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
71     a_node->set_abstract(BuildArg());
72     auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
73     b_node->set_abstract(BuildArg());
74     auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
75     c_node->set_abstract(BuildArg());
76     auto d_node =
77       g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
78     d_node->set_abstract(BuildArg());
79     g->set_output(d_node);
80     return g;
81   }
82 
BuildBpropFuncGraph(const FuncGraphPtr & primal_fg)83   FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
84     auto input_params = primal_fg->parameters();
85     std::vector<ValuePtr> input_param_values;
86     std::for_each(input_params.begin(), input_params.end(),
87                   [&](const AnfNodePtr &param) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
88     auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
89     auto node_list = TopoSort(primal_fg->output());
90     for (auto node : node_list) {
91       if (node->isa<CNode>()) {
92         auto c_node = node->cast<CNodePtr>();
93         auto out = c_node->abstract()->GetValueTrack();
94         ValuePtrList args;
95         for (size_t i = 1; i < c_node->inputs().size(); ++i) {
96           args.push_back(c_node->input(i)->abstract()->GetValueTrack());
97         }
98         GradPynativeOp(k_pynative_cell, c_node, args, out);
99       }
100     }
101     auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false, false, true);
102     return bprop_fg;
103   }
104 };
105 
// Differentiates the single-op graph `return x * y` (despite the test name, the op is a
// multiply) and checks that a bprop graph is actually produced.
TEST_F(TestKPynative, test_simple_add) {
  auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
  resource->manager()->KeepRoots({primal_fg});

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  // Without this check the test only verifies "did not crash".
  ASSERT_NE(bprop_fg, nullptr);
  resource->manager()->KeepRoots({bprop_fg});
}
113 
// Differentiates a graph containing a stop_gradient node and checks that a bprop graph
// is actually produced.
TEST_F(TestKPynative, test_stop_gradient) {
  auto primal_fg = BuildStopGradient("test_stop_gradient");
  resource->manager()->KeepRoots({primal_fg});

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  // Without this check the test only verifies "did not crash".
  ASSERT_NE(bprop_fg, nullptr);
  resource->manager()->KeepRoots({bprop_fg});
}
121 }  // namespace ad
122 }  // namespace mindspore
123