/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>
#include "ir/meta_func_graph.h"
#include "ir/func_graph.h"

#include "include/common/pybind_api/api_register.h"

namespace mindspore {
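// Rebuild each hyper parameter of `func_graph` (i.e. each graph parameter that has a default value):
// if `params_init` provides a tensor with the same name, shape and dtype, it replaces the old default
// value; otherwise the old value is kept. The rebuilt Parameter objects are returned keyed by name.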
py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::dict &params_init) {
  py::dict hyper_params;
  for (const auto &param : func_graph->parameters()) {
    auto param_node = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_node);
    py::str param_name = py::str(param_node->name());
    if (param_node->has_default()) {
      // Import the Python-side Parameter and ParamInfo classes used to rebuild the parameter.
      const char kModelName[] = "mindspore.common.parameter";
      const char kParamName[] = "Parameter";
      const char kParamInfoName[] = "ParamInfo";
      const py::module &mod = py::module::import(kModelName);
      const py::object &ParamInit = mod.attr(kParamName);
      const py::object &ParamInfoInit = mod.attr(kParamInfoName);
      const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(old_value);
      py::object new_param;

      // Use the value from params_init if one is provided for this parameter; otherwise keep the old default.
      if (params_init.contains(param_name)) {
        const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>();
        MS_EXCEPTION_IF_NULL(new_value);
        if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
          MS_EXCEPTION(ValueError)
            << "Only support updating a parameter with a Tensor or Parameter of the same shape and dtype as it. "
               "The parameter '"
            << param_name.cast<std::string>() << "' has shape " << old_value->shape() << " and dtype "
            << TypeIdLabel(old_value->data_type()) << ", but got an update value with shape " << new_value->shape()
            << " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
        }
        new_param = ParamInit(*new_value);
      } else {
        new_param = ParamInit(*old_value);
      }
      py::setattr(new_param, "param_info", ParamInfoInit(old_value->param_info()));
      param_node->set_default_param(new_param.cast<tensor::TensorPtr>());
      hyper_params[param_name] = new_param;
    }
  }
  return hyper_params;
}

void RegFuncGraph(const py::module *m) {
  // Define python "FuncGraph" class
  (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
    .def(py::init())
    .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
    .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
}

void RegMetaFuncGraph(const py::module *m) {
  // Define python "MetaFuncGraph_" class
  (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
    .def("set_signatures", &MetaFuncGraph::set_signatures, "Set the primitive's input signatures.");
}

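// Exposes UpdateFuncGraphHyperParams to Python as update_func_graph_hyper_params(func_graph, params_init).
// A rough sketch of the Python-side call (the exact module that re-exports this binding may differ):
//   new_params = update_func_graph_hyper_params(func_graph, {"weight": Tensor(np.zeros((2, 2), np.float32))})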
void RegUpdateFuncGraphHyperParams(py::module *m) {
  (void)m->def("update_func_graph_hyper_params", &UpdateFuncGraphHyperParams, py::arg("func_graph"),
               py::arg("params_init"), "Update FuncGraph hyper parameters, and return the updated parameters.");
}
}  // namespace mindspore