1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <string>
17 #include "ir/meta_func_graph.h"
18 #include "ir/func_graph.h"
19
20 #include "include/common/pybind_api/api_register.h"
21
22 namespace mindspore {
// Rebuilds the default values of a FuncGraph's parameters as fresh Python `Parameter` objects,
// optionally replacing them with user-supplied values.
//
// For every parameter of `func_graph` that has a default value:
//  - if `params_init` has an entry under the parameter's name, that entry is cast to a Tensor and
//    must match the current default's shape and dtype; it becomes the new value;
//  - otherwise the current default tensor is reused.
// A new `mindspore.common.parameter.Parameter` is constructed from the chosen tensor. Its
// `param_info` attribute is set to a `ParamInfo` built from the old default's param_info. It is
// then installed as the parameter node's new default value.
//
// @param func_graph  Graph whose parameter defaults are replaced in place; must not be null.
// @param params_init Python dict mapping parameter names to replacement Tensor/Parameter values.
// @return Python dict mapping each parameter name that had a default to its new Parameter object.
// @throws ValueError if a replacement value's shape or dtype differs from the original default.
py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::dict &params_init) {
  constexpr char kModuleName[] = "mindspore.common.parameter";
  constexpr char kParamName[] = "Parameter";
  constexpr char kParamInfoName[] = "ParamInfo";
  MS_EXCEPTION_IF_NULL(func_graph);

  // The Python classes are resolved lazily on first use and reused for every parameter, rather than
  // re-importing the module on each iteration. A graph without default parameters never imports it.
  py::object param_cls;
  py::object param_info_cls;

  py::dict hyper_params;
  for (const auto &param : func_graph->parameters()) {
    auto param_node = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_node);
    if (!param_node->has_default()) {
      continue;
    }
    if (!param_cls) {
      const py::module mod = py::module::import(kModuleName);
      param_cls = mod.attr(kParamName);
      param_info_cls = mod.attr(kParamInfoName);
    }

    const py::str param_name(param_node->name());
    const auto old_value = param_node->default_param()->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(old_value);

    py::object new_param;
    if (params_init.contains(param_name)) {
      const auto new_value = params_init[param_name].cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(new_value);
      // Only a like-for-like replacement is allowed: the compiled graph depends on the original
      // shape and dtype.
      if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
        MS_EXCEPTION(ValueError)
          << "Only support update parameter by Tensor or Parameter with same shape and dtype as it. "
             "The parameter '"
          << param_name.cast<std::string>() << "' has shape " << old_value->shape() << " and dtype "
          << TypeIdLabel(old_value->data_type()) << ", but got the update value with shape " << new_value->shape()
          << " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
      }
      new_param = param_cls(*new_value);
    } else {
      new_param = param_cls(*old_value);
    }
    // Carry the original parameter metadata over to the newly created Parameter.
    py::setattr(new_param, "param_info", param_info_cls(old_value->param_info()));
    param_node->set_default_param(new_param.cast<tensor::TensorPtr>());
    hyper_params[param_name] = new_param;
  }
  return hyper_params;
}
62
RegFuncGraph(const py::module * m)63 void RegFuncGraph(const py::module *m) {
64 // Define python "FuncGraph" class
65 (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
66 .def(py::init())
67 .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
68 .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
69 }
70
RegMetaFuncGraph(const py::module * m)71 void RegMetaFuncGraph(const py::module *m) {
72 // Define python "MetaFuncGraph_" class
73 (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
74 .def("set_signatures", &MetaFuncGraph::set_signatures, "Set primitive inputs signature.");
75 }
76
RegUpdateFuncGraphHyperParams(py::module * m)77 void RegUpdateFuncGraphHyperParams(py::module *m) {
78 (void)m->def("update_func_graph_hyper_params", &UpdateFuncGraphHyperParams, py::arg("func_graph"),
79 py::arg("params_init"), "Update FuncGraph hyper parameters, and return the updated parameters.");
80 }
81 } // namespace mindspore
82